1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
16
17 #include "grpcpp/grpcpp.h"
18
19 #include <cstdio>
20 #include <ctime>
21 #include <vector>
22
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/grpc_services.h"
30 #include "tensorflow/core/profiler/rpc/client/dump_tpu_profile.h"
31 #include "tensorflow/core/util/events_writer.h"
32
33 namespace tensorflow {
34 namespace profiler {
35 namespace client {
36
// Value passed to ProfileRequest::set_max_events() for every request built by
// PopulateProfileRequest().
constexpr uint64 kMaxEvents = 1000000;
38
// Returns the current local time formatted with strftime's "%F_%T", i.e.
// "YYYY-MM-DD_HH:MM:SS".
//
// If the local time cannot be obtained or formatted, falls back to the raw
// time value printed as a decimal integer, so the result is never read from an
// uninitialized buffer.
string GetCurrentTimeStampAsString() {
  char buffer[128] = {};
  const std::time_t now = std::time(nullptr);
  // std::localtime may return nullptr on failure, and strftime returns 0 when
  // it cannot produce the output; both must be checked before using `buffer`.
  // NOTE: std::localtime returns a pointer to shared static storage, so this
  // helper is not safe to call concurrently with other users of that storage.
  const std::tm* local_time = std::localtime(&now);
  if (local_time == nullptr ||
      std::strftime(buffer, sizeof(buffer), "%F_%T", local_time) == 0) {
    std::snprintf(buffer, sizeof(buffer), "%lld", static_cast<long long>(now));
  }
  return buffer;
}
46
ValidateHostPortPair(const string & host_port)47 Status ValidateHostPortPair(const string& host_port) {
48 uint32 port;
49 std::vector<string> parts = str_util::Split(host_port, ':');
50 // Must be host:port, port must be a number, host must not contain a '/',
51 // host also must not be empty.
52 if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
53 parts[0].find("/") != string::npos || parts[0].empty()) {
54 return errors::InvalidArgument("Could not interpret \"", host_port,
55 "\" as a host-port pair.");
56 }
57 return Status::OK();
58 }
59
PopulateProfileRequest(int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)60 ProfileRequest PopulateProfileRequest(int duration_ms,
61 const string& repository_root,
62 const string& session_id,
63 const ProfileOptions& opts) {
64 ProfileRequest request;
65 request.set_duration_ms(duration_ms);
66 request.set_max_events(kMaxEvents);
67 if (tensorflow::str_util::StartsWith(repository_root, "gs://")) {
68 // For backward compatibilities, only generate tracetable etc when the
69 // user provide a GCS path for model directory.
70 request.set_repository_root(repository_root);
71 request.set_session_id(session_id);
72 }
73 request.add_tools("op_profile");
74 request.add_tools("input_pipeline");
75 request.add_tools("memory_viewer");
76 request.add_tools("overview_page");
77 *request.mutable_opts() = opts;
78 return request;
79 }
80
81 // Returns whether the returned trace is empty.
82 // Failure are handled by CHECK, i.e. abort()
Profile(const string & service_addr,const string & logdir,int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)83 Status Profile(const string& service_addr, const string& logdir,
84 int duration_ms, const string& repository_root,
85 const string& session_id, const ProfileOptions& opts) {
86 ProfileRequest request =
87 PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
88
89 ::grpc::ClientContext context;
90 ::grpc::ChannelArguments channel_args;
91 // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
92 // `ValidateHostPortPair` checks for empty host string case.
93 channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
94 std::numeric_limits<int32>::max());
95 std::unique_ptr<grpc::ProfilerService::Stub> stub =
96 grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
97 "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
98 channel_args));
99 ProfileResponse response;
100 TF_RETURN_IF_ERROR(
101 FromGrpcStatus(stub->Profile(&context, request, &response)));
102
103 if (!response.encoded_trace().empty()) {
104 TF_CHECK_OK(WriteTensorboardTPUProfile(logdir, session_id, "", response,
105 &std::cout));
106 // Print this at the end so that it's not buried in irrelevant LOG messages.
107 std::cout
108 << "NOTE: using the trace duration " << duration_ms << "ms."
109 << std::endl
110 << "Set an appropriate duration (with --duration_ms) if you "
111 "don't see a full step in your trace or the captured trace is too "
112 "large."
113 << std::endl;
114 }
115
116 if (response.encoded_trace().empty()) {
117 return Status(tensorflow::error::Code::UNAVAILABLE,
118 "No trace event is collected");
119 }
120 return Status::OK();
121 }
122
123 // Start a new profiling session that include all the hosts included in
124 // hostnames, for the time interval of duration_ms. Possibly save the profiling
125 // result in the directory specified by repository_root and session_id.
NewSession(const string & service_addr,const std::vector<tensorflow::string> & hostnames,int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)126 Status NewSession(const string& service_addr,
127 const std::vector<tensorflow::string>& hostnames,
128 int duration_ms, const string& repository_root,
129 const string& session_id, const ProfileOptions& opts) {
130 NewProfileSessionRequest new_session_request;
131 *new_session_request.mutable_request() =
132 PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
133 new_session_request.set_repository_root(repository_root);
134 new_session_request.set_session_id(session_id);
135 for (const auto& hostname : hostnames) {
136 new_session_request.add_hosts(hostname);
137 }
138
139 ::grpc::ClientContext context;
140 ::grpc::ChannelArguments channel_args;
141 // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
142 // `ValidateHostPortPair` checks for empty host string case.
143 channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
144 // TODO(jiesun): GRPC support following relevant naming scheme:
145 // 1. dns:///host:port
146 // 2. ipv4:host:port or ipv6:[host]:port
147 // We might need to change the prefix which depends on what TPU name resolver
148 // will give us.
149 std::unique_ptr<grpc::ProfileAnalysis::Stub> stub =
150 grpc::ProfileAnalysis::NewStub(::grpc::CreateCustomChannel(
151 "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
152 channel_args));
153 NewProfileSessionResponse new_session_response;
154 TF_RETURN_IF_ERROR(FromGrpcStatus(
155 stub->NewSession(&context, new_session_request, &new_session_response)));
156
157 std::cout << "Profile session succeed for host(s):"
158 << str_util::Join(hostnames, ",") << std::endl;
159 if (new_session_response.empty_trace()) {
160 return Status(tensorflow::error::Code::UNAVAILABLE,
161 "No trace event is collected");
162 }
163 return Status::OK();
164 }
165
166 // Creates an empty event file if not already exists, which indicates that we
167 // have a plugins/profile/ directory in the current logdir.
MaybeCreateEmptyEventFile(const tensorflow::string & logdir)168 Status MaybeCreateEmptyEventFile(const tensorflow::string& logdir) {
169 // Suffix for an empty event file. it should be kept in sync with
170 // _EVENT_FILE_SUFFIX in tensorflow/python/eager/profiler.py.
171 constexpr char kProfileEmptySuffix[] = ".profile-empty";
172 std::vector<string> children;
173 TF_RETURN_IF_ERROR(Env::Default()->GetChildren(logdir, &children));
174 for (const string& child : children) {
175 if (str_util::EndsWith(child, kProfileEmptySuffix)) {
176 return Status::OK();
177 }
178 }
179 EventsWriter event_writer(io::JoinPath(logdir, "events"));
180 return event_writer.InitWithSuffix(kProfileEmptySuffix);
181 }
182
183 // Starts tracing on a single or multiple TPU hosts and saves the result in the
184 // given logdir. If no trace was collected, retries tracing for
185 // num_tracing_attempts.
StartTracing(const tensorflow::string & service_addr,const tensorflow::string & logdir,const tensorflow::string & workers_list,bool include_dataset_ops,int duration_ms,int num_tracing_attempts)186 Status StartTracing(const tensorflow::string& service_addr,
187 const tensorflow::string& logdir,
188 const tensorflow::string& workers_list,
189 bool include_dataset_ops, int duration_ms,
190 int num_tracing_attempts) {
191 // Use the current timestamp as the run name.
192 tensorflow::string session_id = GetCurrentTimeStampAsString();
193 constexpr char kProfilePluginDirectory[] = "plugins/profile/";
194 tensorflow::string repository_root =
195 io::JoinPath(logdir, kProfilePluginDirectory);
196 std::vector<tensorflow::string> hostnames =
197 tensorflow::str_util::Split(workers_list, ",");
198
199 TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
200
201 Status status = Status::OK();
202 int remaining_attempts = num_tracing_attempts;
203 tensorflow::ProfileOptions opts;
204 opts.set_include_dataset_ops(include_dataset_ops);
205 while (true) {
206 std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
207 << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
208 if (hostnames.empty()) {
209 status = Profile(service_addr, logdir, duration_ms, repository_root,
210 session_id, opts);
211 } else {
212 tensorflow::string tpu_master = service_addr;
213 status = NewSession(tpu_master, hostnames, duration_ms, repository_root,
214 session_id, opts);
215 }
216 if (remaining_attempts <= 0 || status.ok() ||
217 status.code() != tensorflow::error::Code::UNAVAILABLE)
218 break;
219 std::cout << "No trace event is collected. Automatically retrying."
220 << std::endl
221 << std::endl;
222 }
223
224 if (status.code() == tensorflow::error::Code::UNAVAILABLE) {
225 std::cout << "No trace event is collected after " << num_tracing_attempts
226 << " attempt(s). "
227 << "Perhaps, you want to try again (with more attempts?)."
228 << std::endl
229 << "Tip: increase number of attempts with --num_tracing_attempts."
230 << std::endl;
231 }
232 return status;
233 }
234
PopulateMonitorRequest(int duration_ms,int monitoring_level)235 MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) {
236 MonitorRequest request;
237 request.set_duration_ms(duration_ms);
238 request.set_monitoring_level(monitoring_level);
239 return request;
240 }
241
242 // Repeatedly collects profiles and shows user-friendly metrics for
243 // 'num_queries' time(s).
StartMonitoring(const tensorflow::string & service_addr,int duration_ms,int monitoring_level,int num_queries)244 void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
245 int monitoring_level, int num_queries) {
246 for (int query = 0; query < num_queries; ++query) {
247 MonitorRequest request =
248 PopulateMonitorRequest(duration_ms, monitoring_level);
249
250 ::grpc::ClientContext context;
251 ::grpc::ChannelArguments channel_args;
252 channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
253 std::numeric_limits<int32>::max());
254 std::unique_ptr<grpc::ProfilerService::Stub> stub =
255 grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
256 "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
257 channel_args));
258 MonitorResponse response;
259 TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
260
261 std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
262 << "):\n\n"
263 << response.data() << std::flush;
264 }
265 }
266
267 } // namespace client
268 } // namespace profiler
269 } // namespace tensorflow
270