• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/rpc/client/capture_profile.h"
16 
17 #include "grpcpp/grpcpp.h"
18 
19 #include <cstdio>
20 #include <ctime>
21 #include <vector>
22 
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/platform/grpc_services.h"
30 #include "tensorflow/core/profiler/rpc/client/dump_tpu_profile.h"
31 #include "tensorflow/core/util/events_writer.h"
32 
33 namespace tensorflow {
34 namespace profiler {
35 namespace client {
36 
37 constexpr uint64 kMaxEvents = 1000000;
38 
GetCurrentTimeStampAsString()39 string GetCurrentTimeStampAsString() {
40   char s[128];
41   std::time_t t = std::time(nullptr);
42   auto result = std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t));
43   DCHECK_NE(result, 0);
44   return s;
45 }
46 
ValidateHostPortPair(const string & host_port)47 Status ValidateHostPortPair(const string& host_port) {
48   uint32 port;
49   std::vector<string> parts = str_util::Split(host_port, ':');
50   // Must be host:port, port must be a number, host must not contain a '/',
51   // host also must not be empty.
52   if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
53       parts[0].find("/") != string::npos || parts[0].empty()) {
54     return errors::InvalidArgument("Could not interpret \"", host_port,
55                                    "\" as a host-port pair.");
56   }
57   return Status::OK();
58 }
59 
PopulateProfileRequest(int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)60 ProfileRequest PopulateProfileRequest(int duration_ms,
61                                       const string& repository_root,
62                                       const string& session_id,
63                                       const ProfileOptions& opts) {
64   ProfileRequest request;
65   request.set_duration_ms(duration_ms);
66   request.set_max_events(kMaxEvents);
67   if (tensorflow::str_util::StartsWith(repository_root, "gs://")) {
68     // For backward compatibilities, only generate tracetable etc when the
69     // user provide a GCS path for model directory.
70     request.set_repository_root(repository_root);
71     request.set_session_id(session_id);
72   }
73   request.add_tools("op_profile");
74   request.add_tools("input_pipeline");
75   request.add_tools("memory_viewer");
76   request.add_tools("overview_page");
77   *request.mutable_opts() = opts;
78   return request;
79 }
80 
81 // Returns whether the returned trace is empty.
82 // Failure are handled by CHECK, i.e. abort()
Profile(const string & service_addr,const string & logdir,int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)83 Status Profile(const string& service_addr, const string& logdir,
84                int duration_ms, const string& repository_root,
85                const string& session_id, const ProfileOptions& opts) {
86   ProfileRequest request =
87       PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
88 
89   ::grpc::ClientContext context;
90   ::grpc::ChannelArguments channel_args;
91   // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
92   // `ValidateHostPortPair` checks for empty host string case.
93   channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
94                       std::numeric_limits<int32>::max());
95   std::unique_ptr<grpc::ProfilerService::Stub> stub =
96       grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
97           "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
98           channel_args));
99   ProfileResponse response;
100   TF_RETURN_IF_ERROR(
101       FromGrpcStatus(stub->Profile(&context, request, &response)));
102 
103   if (!response.encoded_trace().empty()) {
104     TF_CHECK_OK(WriteTensorboardTPUProfile(logdir, session_id, "", response,
105                                            &std::cout));
106     // Print this at the end so that it's not buried in irrelevant LOG messages.
107     std::cout
108         << "NOTE: using the trace duration " << duration_ms << "ms."
109         << std::endl
110         << "Set an appropriate duration (with --duration_ms) if you "
111            "don't see a full step in your trace or the captured trace is too "
112            "large."
113         << std::endl;
114   }
115 
116   if (response.encoded_trace().empty()) {
117     return Status(tensorflow::error::Code::UNAVAILABLE,
118                   "No trace event is collected");
119   }
120   return Status::OK();
121 }
122 
123 // Start a new profiling session that include all the hosts included in
124 // hostnames, for the time interval of duration_ms. Possibly save the profiling
125 // result in the directory specified by repository_root and session_id.
NewSession(const string & service_addr,const std::vector<tensorflow::string> & hostnames,int duration_ms,const string & repository_root,const string & session_id,const ProfileOptions & opts)126 Status NewSession(const string& service_addr,
127                   const std::vector<tensorflow::string>& hostnames,
128                   int duration_ms, const string& repository_root,
129                   const string& session_id, const ProfileOptions& opts) {
130   NewProfileSessionRequest new_session_request;
131   *new_session_request.mutable_request() =
132       PopulateProfileRequest(duration_ms, repository_root, session_id, opts);
133   new_session_request.set_repository_root(repository_root);
134   new_session_request.set_session_id(session_id);
135   for (const auto& hostname : hostnames) {
136     new_session_request.add_hosts(hostname);
137   }
138 
139   ::grpc::ClientContext context;
140   ::grpc::ChannelArguments channel_args;
141   // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
142   // `ValidateHostPortPair` checks for empty host string case.
143   channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
144   // TODO(jiesun): GRPC support following relevant naming scheme:
145   // 1. dns:///host:port
146   // 2. ipv4:host:port or ipv6:[host]:port
147   // We might need to change the prefix which depends on what TPU name resolver
148   // will give us.
149   std::unique_ptr<grpc::ProfileAnalysis::Stub> stub =
150       grpc::ProfileAnalysis::NewStub(::grpc::CreateCustomChannel(
151           "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
152           channel_args));
153   NewProfileSessionResponse new_session_response;
154   TF_RETURN_IF_ERROR(FromGrpcStatus(
155       stub->NewSession(&context, new_session_request, &new_session_response)));
156 
157   std::cout << "Profile session succeed for host(s):"
158             << str_util::Join(hostnames, ",") << std::endl;
159   if (new_session_response.empty_trace()) {
160     return Status(tensorflow::error::Code::UNAVAILABLE,
161                   "No trace event is collected");
162   }
163   return Status::OK();
164 }
165 
166 // Creates an empty event file if not already exists, which indicates that we
167 // have a plugins/profile/ directory in the current logdir.
MaybeCreateEmptyEventFile(const tensorflow::string & logdir)168 Status MaybeCreateEmptyEventFile(const tensorflow::string& logdir) {
169   // Suffix for an empty event file.  it should be kept in sync with
170   // _EVENT_FILE_SUFFIX in tensorflow/python/eager/profiler.py.
171   constexpr char kProfileEmptySuffix[] = ".profile-empty";
172   std::vector<string> children;
173   TF_RETURN_IF_ERROR(Env::Default()->GetChildren(logdir, &children));
174   for (const string& child : children) {
175     if (str_util::EndsWith(child, kProfileEmptySuffix)) {
176       return Status::OK();
177     }
178   }
179   EventsWriter event_writer(io::JoinPath(logdir, "events"));
180   return event_writer.InitWithSuffix(kProfileEmptySuffix);
181 }
182 
183 // Starts tracing on a single or multiple TPU hosts and saves the result in the
184 // given logdir. If no trace was collected, retries tracing for
185 // num_tracing_attempts.
StartTracing(const tensorflow::string & service_addr,const tensorflow::string & logdir,const tensorflow::string & workers_list,bool include_dataset_ops,int duration_ms,int num_tracing_attempts)186 Status StartTracing(const tensorflow::string& service_addr,
187                     const tensorflow::string& logdir,
188                     const tensorflow::string& workers_list,
189                     bool include_dataset_ops, int duration_ms,
190                     int num_tracing_attempts) {
191   // Use the current timestamp as the run name.
192   tensorflow::string session_id = GetCurrentTimeStampAsString();
193   constexpr char kProfilePluginDirectory[] = "plugins/profile/";
194   tensorflow::string repository_root =
195       io::JoinPath(logdir, kProfilePluginDirectory);
196   std::vector<tensorflow::string> hostnames =
197       tensorflow::str_util::Split(workers_list, ",");
198 
199   TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir));
200 
201   Status status = Status::OK();
202   int remaining_attempts = num_tracing_attempts;
203   tensorflow::ProfileOptions opts;
204   opts.set_include_dataset_ops(include_dataset_ops);
205   while (true) {
206     std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
207               << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
208     if (hostnames.empty()) {
209       status = Profile(service_addr, logdir, duration_ms, repository_root,
210                        session_id, opts);
211     } else {
212       tensorflow::string tpu_master = service_addr;
213       status = NewSession(tpu_master, hostnames, duration_ms, repository_root,
214                           session_id, opts);
215     }
216     if (remaining_attempts <= 0 || status.ok() ||
217         status.code() != tensorflow::error::Code::UNAVAILABLE)
218       break;
219     std::cout << "No trace event is collected. Automatically retrying."
220               << std::endl
221               << std::endl;
222   }
223 
224   if (status.code() == tensorflow::error::Code::UNAVAILABLE) {
225     std::cout << "No trace event is collected after " << num_tracing_attempts
226               << " attempt(s). "
227               << "Perhaps, you want to try again (with more attempts?)."
228               << std::endl
229               << "Tip: increase number of attempts with --num_tracing_attempts."
230               << std::endl;
231   }
232   return status;
233 }
234 
PopulateMonitorRequest(int duration_ms,int monitoring_level)235 MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) {
236   MonitorRequest request;
237   request.set_duration_ms(duration_ms);
238   request.set_monitoring_level(monitoring_level);
239   return request;
240 }
241 
242 // Repeatedly collects profiles and shows user-friendly metrics for
243 // 'num_queries' time(s).
StartMonitoring(const tensorflow::string & service_addr,int duration_ms,int monitoring_level,int num_queries)244 void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
245                      int monitoring_level, int num_queries) {
246   for (int query = 0; query < num_queries; ++query) {
247     MonitorRequest request =
248         PopulateMonitorRequest(duration_ms, monitoring_level);
249 
250     ::grpc::ClientContext context;
251     ::grpc::ChannelArguments channel_args;
252     channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
253                         std::numeric_limits<int32>::max());
254     std::unique_ptr<grpc::ProfilerService::Stub> stub =
255         grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
256             "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
257             channel_args));
258     MonitorResponse response;
259     TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
260 
261     std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
262               << "):\n\n"
263               << response.data() << std::flush;
264   }
265 }
266 
267 }  // namespace client
268 }  // namespace profiler
269 }  // namespace tensorflow
270