/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/debug/debug_io_utils.h"

#include <stddef.h>
#include <string.h>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <utility>
#include <vector>

#ifndef PLATFORM_WINDOWS
#include "grpcpp/create_channel.h"
#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib, "Ws2_32.lib")
#endif  // #ifndef PLATFORM_WINDOWS

#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/event.pb.h"

#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
  return errors::Unimplemented(              \
      kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")

namespace tensorflow {

namespace {

// Creates an Event proto representing a chunk of a Tensor. This method only
// populates the fields of the Event proto that represent the envelope
// information (e.g., timestamp, device_name, num_chunks, chunk_index, dtype,
// shape). It does not set the value.tensor field, which should be set by the
// caller separately.
Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
                             const uint64 wall_time_us, const size_t num_chunks,
                             const size_t chunk_index,
                             const DataType& tensor_dtype,
                             const TensorShapeProto& tensor_shape) {
  Event event;
  event.set_wall_time(static_cast<double>(wall_time_us));
  Summary::Value* value = event.mutable_summary()->add_value();

  // Create the debug node_name in the Summary proto.
  // For example, if tensor_name = "foo/node_a:0", and the debug_op is
  // "DebugIdentity", the debug node_name in the Summary proto will be
  // "foo/node_a:0:DebugIdentity".
  value->set_node_name(debug_node_key.debug_node_name);

  // Tag by the node name. This allows TensorBoard to quickly fetch data
  // per op.
  value->set_tag(debug_node_key.node_name);

  // Store debugger metadata for this event.
  third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
  metadata.set_device(debug_node_key.device_name);
  metadata.set_output_slot(debug_node_key.output_slot);
  metadata.set_num_chunks(num_chunks);
  metadata.set_chunk_index(chunk_index);

  // Encode the data in JSON.
  string json_output;
  tensorflow::protobuf::util::JsonPrintOptions json_options;
  json_options.always_print_primitive_fields = true;
  auto status = tensorflow::protobuf::util::MessageToJsonString(
      metadata, &json_output, json_options);
  if (status.ok()) {
    // Store summary metadata. Set the plugin to use this data as "debugger".
    SummaryMetadata::PluginData* plugin_data =
        value->mutable_metadata()->mutable_plugin_data();
    plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName);
    plugin_data->set_content(json_output);
  } else {
    LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. "
                 << "The debug_node_name is " << debug_node_key.debug_node_name
                 << ".";
  }

  value->mutable_tensor()->set_dtype(tensor_dtype);
  *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape;

  return event;
}

// Translates the length of a string to number of bytes when the string is
// encoded as bytes in protobuf. Note that this makes a conservative estimate
// (i.e., an estimate that is usually too large, but never too small under the
// gRPC message size limit) of the Varint-encoded length, to work around the
// lack of a portable length function.
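// For example, a 10-byte string is counted as 10 + kGrpcMaxVarintLengthSize
// = 16 bytes when PLATFORM_GOOGLE is defined, and as exactly 10 bytes
// otherwise.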
const size_t StringValMaxBytesInProto(const string& str) {
#if defined(PLATFORM_GOOGLE)
  return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
#else
  return str.size();
#endif
}

// Breaks a string Tensor (represented as a TensorProto) into a vector of
// Event protos.
Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
                                const uint64 wall_time_us,
                                const size_t chunk_size_limit,
                                TensorProto* tensor_proto,
                                std::vector<Event>* events) {
  const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
  const size_t num_strs = strs.size();
  const size_t chunk_size_ub = chunk_size_limit > 0
                                   ? chunk_size_limit
                                   : std::numeric_limits<size_t>::max();

  // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges:
  // [0:j), [j:k), [k:l).
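  // Concretely, if there are 5 strings and the first two fill one chunk while
  // the remaining three fill another, cutoffs ends up as {2, 5} and the two
  // chunks cover the index ranges [0:2) and [2:5).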
  std::vector<size_t> cutoffs;
  size_t chunk_size = 0;
  for (size_t i = 0; i < num_strs; ++i) {
    // Take into account the extra bytes in proto buffer.
    if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
      return errors::FailedPrecondition(
          "string value at index ", i, " from debug node ",
          debug_node_key.debug_node_name,
          " does not fit gRPC message size limit (", chunk_size_ub, ")");
    }
    if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
      cutoffs.push_back(i);
      chunk_size = 0;
    }
    chunk_size += StringValMaxBytesInProto(strs[i]);
  }
  cutoffs.push_back(num_strs);
  const size_t num_chunks = cutoffs.size();

  for (size_t i = 0; i < num_chunks; ++i) {
    Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
                                         num_chunks, i, tensor_proto->dtype(),
                                         tensor_proto->tensor_shape());
    Summary::Value* value = event.mutable_summary()->mutable_value(0);

    if (cutoffs.size() == 1) {
      value->mutable_tensor()->mutable_string_val()->Swap(
          tensor_proto->mutable_string_val());
    } else {
      const size_t begin = (i == 0) ? 0 : cutoffs[i - 1];
      const size_t end = cutoffs[i];
      for (size_t j = begin; j < end; ++j) {
        value->mutable_tensor()->add_string_val(strs[j]);
      }
    }

    events->push_back(std::move(event));
  }

  return Status::OK();
}

// Encapsulates the tensor value inside a vector of Event protos. Large tensors
// are broken up into multiple protos to fit the chunk_size_limit. In each
// Event proto the field summary.tensor carries the content of the tensor.
// If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
// length-1 vector will be returned, regardless of the size of the tensor.
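// For example, a non-string tensor whose serialized tensor_content is 10 MB,
// published with a 4-MB chunk_size_limit, is wrapped into ceil(10 / 4) = 3
// Event protos.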
Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
                          const Tensor& tensor, const uint64 wall_time_us,
                          const size_t chunk_size_limit,
                          std::vector<Event>* events) {
  TensorProto tensor_proto;
  if (tensor.dtype() == DT_STRING) {
    // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can
    // convert the TensorProto to string-type numpy array. MakeNdarray does not
    // work with strings encoded by AsProtoTensorContent() in tensor_content.
    tensor.AsProtoField(&tensor_proto);

    TF_RETURN_IF_ERROR(WrapStringTensorAsEvents(
        debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events));
  } else {
    tensor.AsProtoTensorContent(&tensor_proto);

    const size_t total_length = tensor_proto.tensor_content().size();
    const size_t chunk_size_ub =
        chunk_size_limit > 0 ? chunk_size_limit : total_length;
    const size_t num_chunks =
        (total_length == 0)
            ? 1
            : (total_length + chunk_size_ub - 1) / chunk_size_ub;
    for (size_t i = 0; i < num_chunks; ++i) {
      const size_t pos = i * chunk_size_ub;
      const size_t len =
          (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub;
      Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
                                           num_chunks, i, tensor_proto.dtype(),
                                           tensor_proto.tensor_shape());
      event.mutable_summary()
          ->mutable_value(0)
          ->mutable_tensor()
          ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len));
      events->push_back(std::move(event));
    }
  }

  return Status::OK();
}

// Appends an underscore and a timestamp to a file path. If the path already
// exists on the file system, appends a hyphen and a 1-up index. Consecutive
// values of the index will be tried until the first unused one is found.
// A TOCTOU race condition is not of concern here because tfdbg sets the
// parallel_iterations attribute of all while_loops to 1 to prevent the same
// node from being executed multiple times concurrently.
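// For example, "/tmp/dump" with timestamp 1000 becomes "/tmp/dump_1000"; if
// that path already exists, "/tmp/dump_1000-1", "/tmp/dump_1000-2", etc. are
// tried until an unused path is found.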
string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
  string out = strings::StrCat(in, "_", timestamp);

  uint64 i = 1;
  while (Env::Default()->FileExists(out).ok()) {
    out = strings::StrCat(in, "_", timestamp, "-", i);
    ++i;
  }
  return out;
}

#ifndef PLATFORM_WINDOWS
// Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
// conforming to the gRPC message size limit.
Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
                                      const string& device_name,
                                      const int64 wall_time,
                                      const string& debug_url) {
  const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
  const size_t total_length = encoded_graph_def.size();
  const size_t num_chunks =
      static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
                                    DebugGrpcIO::kGrpcMessageSizeLimitBytes));
  for (size_t i = 0; i < num_chunks; ++i) {
    const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    const size_t len = (i == num_chunks - 1)
                           ? (total_length - pos)
                           : DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    Event event;
    event.set_wall_time(static_cast<double>(wall_time));
    // Prefix the chunk with
    // <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|.
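    // For example, the second of three chunks sent with hash 123, device name
    // "/job:localhost/replica:0/task:0/device:CPU:0" and wall_time 1000 would
    // carry the prefix
    // "123,/job:localhost/replica:0/task:0/device:CPU:0,1000|1|3|".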
    // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks
    // and chunk_index, instead.
    event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
                                        "|", i, "|", num_chunks, "|",
                                        encoded_graph_def.substr(pos, len)));
    const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
        event, debug_url, num_chunks - 1 == i);
    if (!s.ok()) {
      return errors::FailedPrecondition(
          "Failed to send chunk ", i, " of ", num_chunks,
          " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
          "due to: ", s.error_message());
    }
  }
  return Status::OK();
}
#endif  // #ifndef PLATFORM_WINDOWS

}  // namespace

const char* const DebugIO::kDebuggerPluginName = "debugger";

const char* const DebugIO::kCoreMetadataTag = "core_metadata_";

const char* const DebugIO::kGraphTag = "graph_";

const char* const DebugIO::kHashTag = "hash";

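// Reads an Event proto from a dump file previously written by tfdbg.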
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
  Env* env(Env::Default());

  string content;
  uint64 file_size = 0;

  Status s = env->GetFileSize(dump_file_path, &file_size);
  if (!s.ok()) {
    return s;
  }

  content.resize(file_size);

  std::unique_ptr<RandomAccessFile> file;
  s = env->NewRandomAccessFile(dump_file_path, &file);
  if (!s.ok()) {
    return s;
  }

  StringPiece result;
  s = file->Read(0, file_size, &result, &(content)[0]);
  if (!s.ok()) {
    return s;
  }

  event->ParseFromString(content);
  return Status::OK();
}

const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";

// Publishes debug metadata to a set of debug URLs.
Status DebugIO::PublishDebugMetadata(
    const int64 global_step, const int64 session_run_index,
    const int64 executor_step_index, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<string>& target_nodes,
    const std::unordered_set<string>& debug_urls) {
  std::ostringstream oss;

  // Construct a JSON string to carry the metadata.
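  // The result has the form:
  //   {"global_step":-1,"session_run_index":7,"executor_step_index":0,
  //    "input_names":["a:0"],"output_names":["b:0"],"target_nodes":[]}
  // (the values shown here are illustrative only).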
  oss << "{";
  oss << "\"global_step\":" << global_step << ",";
  oss << "\"session_run_index\":" << session_run_index << ",";
  oss << "\"executor_step_index\":" << executor_step_index << ",";
  oss << "\"input_names\":[";
  for (size_t i = 0; i < input_names.size(); ++i) {
    oss << "\"" << input_names[i] << "\"";
    if (i < input_names.size() - 1) {
      oss << ",";
    }
  }
  oss << "],";
  oss << "\"output_names\":[";
  for (size_t i = 0; i < output_names.size(); ++i) {
    oss << "\"" << output_names[i] << "\"";
    if (i < output_names.size() - 1) {
      oss << ",";
    }
  }
  oss << "],";
  oss << "\"target_nodes\":[";
  for (size_t i = 0; i < target_nodes.size(); ++i) {
    oss << "\"" << target_nodes[i] << "\"";
    if (i < target_nodes.size() - 1) {
      oss << ",";
    }
  }
  oss << "]";
  oss << "}";

  const string json_metadata = oss.str();
  Event event;
  event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
  LogMessage* log_message = event.mutable_log_message();
  log_message->set_message(json_metadata);

  Status status;
  for (const string& url : debug_urls) {
    if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
#ifndef PLATFORM_WINDOWS
      Event grpc_event;

      // Determine the path (if any) in the grpc:// URL, and add it as a field
      // of the JSON string.
      const string address = url.substr(strlen(DebugIO::kGrpcURLScheme));
      const string path = address.find("/") == string::npos
                              ? ""
                              : address.substr(address.find("/"));
      grpc_event.set_wall_time(event.wall_time());
      LogMessage* log_message_grpc = grpc_event.mutable_log_message();
      log_message_grpc->set_message(
          strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1),
                          ",\"grpc_path\":\"", path, "\"}"));

      status.Update(
          DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    } else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
      const string dump_root_dir = url.substr(strlen(kFileURLScheme));
      const string core_metadata_path = AppendTimestampToFilePath(
          io::JoinPath(
              dump_root_dir,
              strings::StrCat(DebugNodeKey::kMetadataFilePrefix,
                              DebugIO::kCoreMetadataTag, "sessionrun",
                              strings::Printf("%.14lld", session_run_index))),
          Env::Default()->NowMicros());
      status.Update(DebugFileIO::DumpEventProtoToFile(
          event, string(io::Dirname(core_metadata_path)),
          string(io::Basename(core_metadata_path))));
    }
  }

  return status;
}

Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
                                   const Tensor& tensor,
                                   const uint64 wall_time_us,
                                   const gtl::ArraySlice<string>& debug_urls,
                                   const bool gated_grpc) {
  int32 num_failed_urls = 0;
  std::vector<Status> fail_statuses;
  for (const string& url : debug_urls) {
    if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
      const string dump_root_dir = url.substr(strlen(kFileURLScheme));

      const int64 tensorBytes =
          tensor.IsInitialized() ? tensor.TotalBytes() : 0;
      if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
        return errors::ResourceExhausted(
            "TensorFlow Debugger has exhausted file-system byte-size "
            "allowance (",
            DebugFileIO::globalDiskBytesLimit, "), therefore it cannot ",
            "dump an additional ", tensorBytes, " byte(s) of tensor data ",
            "for the debug tensor ", debug_node_key.node_name, ":",
            debug_node_key.output_slot, ". You may use the environment ",
            "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
      }

      Status s = DebugFileIO::DumpTensorToDir(
          debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
      if (!s.ok()) {
        num_failed_urls++;
        fail_statuses.push_back(s);
      }
    } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
#ifndef PLATFORM_WINDOWS
      Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
          debug_node_key, tensor, wall_time_us, url, gated_grpc);

      if (!s.ok()) {
        num_failed_urls++;
        fail_statuses.push_back(s);
      }
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    } else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
      const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
      auto* callback_registry = DebugCallbackRegistry::singleton();
      auto* callback = callback_registry->GetCallback(dump_root_dir);
      CHECK(callback) << "No callback registered for: " << dump_root_dir;
      (*callback)(debug_node_key, tensor);
    } else {
      return Status(error::UNAVAILABLE,
                    strings::StrCat("Invalid debug target URL: ", url));
    }
  }

  if (num_failed_urls == 0) {
    return Status::OK();
  } else {
    string error_message = strings::StrCat(
        "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
        " debug target URLs failed, due to the following errors:");
    for (Status& status : fail_statuses) {
      error_message =
          strings::StrCat(error_message, " ", status.error_message(), ";");
    }

    return Status(error::INTERNAL, error_message);
  }
}

Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
                                   const Tensor& tensor,
                                   const uint64 wall_time_us,
                                   const gtl::ArraySlice<string>& debug_urls) {
  return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
                            false);
}

Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
                             const std::unordered_set<string>& debug_urls) {
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  string buf;
  graph_def.SerializeToString(&buf);

  const int64 now_micros = Env::Default()->NowMicros();
  Event event;
  event.set_wall_time(static_cast<double>(now_micros));
  event.set_graph_def(buf);

  Status status = Status::OK();
  for (const string& debug_url : debug_urls) {
    if (debug_url.find(kFileURLScheme) == 0) {
      const string dump_root_dir =
          io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
                       DebugNodeKey::DeviceNameToDevicePath(device_name));
      const uint64 graph_hash = ::tensorflow::Hash64(buf);
      const string file_name =
          strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
                          DebugIO::kHashTag, graph_hash, "_", now_micros);

      status.Update(
          DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
    } else if (debug_url.find(kGrpcURLScheme) == 0) {
#ifndef PLATFORM_WINDOWS
      status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
                                                   debug_url));
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    }
  }

  return status;
}

bool DebugIO::IsCopyNodeGateOpen(
    const std::vector<DebugWatchAndURLSpec>& specs) {
#ifndef PLATFORM_WINDOWS
  for (const DebugWatchAndURLSpec& spec : specs) {
    if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
                                             DebugIO::kGrpcURLScheme)) {
      return true;
    } else {
      if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
        return true;
      }
    }
  }
  return false;
#else
  return true;
#endif
}

bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
                                  const std::vector<string>& debug_urls) {
#ifndef PLATFORM_WINDOWS
  for (const string& debug_url : debug_urls) {
    if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
                          DebugIO::kGrpcURLScheme)) {
      return true;
    } else {
      if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
        return true;
      }
    }
  }
  return false;
#else
  return true;
#endif
}

bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
                                 const string& debug_url) {
#ifndef PLATFORM_WINDOWS
  if (debug_url.find(kGrpcURLScheme) != 0) {
    return true;
  } else {
    return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
  }
#else
  return true;
#endif
}

Status DebugIO::CloseDebugURL(const string& debug_url) {
  if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
#ifndef PLATFORM_WINDOWS
    return DebugGrpcIO::CloseGrpcStream(debug_url);
#else
    GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
  } else {
    // No-op for non-gRPC URLs.
    return Status::OK();
  }
}

Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
                                    const Tensor& tensor,
                                    const uint64 wall_time_us,
                                    const string& dump_root_dir,
                                    string* dump_file_path) {
  const string file_path =
      GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);

  if (dump_file_path != nullptr) {
    *dump_file_path = file_path;
  }

  return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
}

string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
                                    const DebugNodeKey& debug_node_key,
                                    const uint64 wall_time_us) {
  return AppendTimestampToFilePath(
      io::JoinPath(dump_root_dir, debug_node_key.device_path,
                   strings::StrCat(debug_node_key.node_name, "_",
                                   debug_node_key.output_slot, "_",
                                   debug_node_key.debug_op)),
      wall_time_us);
}

Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
                                         const string& dir_name,
                                         const string& file_name) {
  Env* env(Env::Default());

  Status s = RecursiveCreateDir(env, dir_name);
  if (!s.ok()) {
    return Status(error::FAILED_PRECONDITION,
                  strings::StrCat("Failed to create directory ", dir_name,
                                  ", due to: ", s.error_message()));
  }

  const string file_path = io::JoinPath(dir_name, file_name);

  string event_str;
  event_proto.SerializeToString(&event_str);

  std::unique_ptr<WritableFile> f = nullptr;
  TF_CHECK_OK(env->NewWritableFile(file_path, &f));
  f->Append(event_str).IgnoreError();
  TF_CHECK_OK(f->Close());

  return Status::OK();
}

Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
                                          const Tensor& tensor,
                                          const uint64 wall_time_us,
                                          const string& file_path) {
  std::vector<Event> events;
  TF_RETURN_IF_ERROR(
      WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
  return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
                              string(io::Basename(file_path)));
}

Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
  if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    // The path already exists as a directory. Return OK right away.
    return Status::OK();
  }

  string parent_dir(io::Dirname(dir));
  if (!env->FileExists(parent_dir).ok()) {
    // The parent path does not exist yet, create it first.
    Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
    if (!s.ok()) {
      return Status(
          error::FAILED_PRECONDITION,
          strings::StrCat("Failed to create directory ", parent_dir));
    }
  } else if (env->FileExists(parent_dir).ok() &&
             !env->IsDirectory(parent_dir).ok()) {
    // The path exists, but it is a file.
    return Status(error::FAILED_PRECONDITION,
                  strings::StrCat("Failed to create directory ", parent_dir,
                                  " because the path exists as a file "));
  }

  env->CreateDir(dir).IgnoreError();
  // Guard against potential race in creating directories by doing a check
  // after the CreateDir call.
  if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    return Status::OK();
  } else {
    return Status(error::ABORTED,
                  strings::StrCat("Failed to create directory ", dir));
  }
}

// Default total disk usage limit: 100 GBytes
const uint64 DebugFileIO::defaultGlobalDiskBytesLimit = 107374182400L;
uint64 DebugFileIO::globalDiskBytesLimit = 0;
uint64 DebugFileIO::diskBytesUsed = 0;

mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);

bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
  mutex_lock l(bytes_mu);
  if (globalDiskBytesLimit == 0) {
    const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
    if (env_tfdbg_disk_bytes_limit == nullptr ||
        strlen(env_tfdbg_disk_bytes_limit) == 0) {
      globalDiskBytesLimit = defaultGlobalDiskBytesLimit;
    } else {
      strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
                             &globalDiskBytesLimit);
    }
  }

  if (bytes == 0) {
    return true;
  }
  if (diskBytesUsed + bytes < globalDiskBytesLimit) {
    diskBytesUsed += bytes;
    return true;
  } else {
    return false;
  }
}

void DebugFileIO::resetDiskByteUsage() {
  mutex_lock l(bytes_mu);
  diskBytesUsed = 0;
}

#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
    : server_stream_addr_(server_stream_addr),
      url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}

Status DebugGrpcChannel::Connect(const int64 timeout_micros) {
  ::grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
  // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
  channel_ = ::grpc::CreateCustomChannel(
      server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
  if (!channel_->WaitForConnected(
          gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
                       gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) {
    return errors::FailedPrecondition(
        "Failed to connect to gRPC channel at ", server_stream_addr_,
        " within a timeout of ", timeout_micros / 1e6, " s.");
  }
  stub_ = EventListener::NewStub(channel_);
  reader_writer_ = stub_->SendEvents(&ctx_);

  return Status::OK();
}

bool DebugGrpcChannel::WriteEvent(const Event& event) {
  mutex_lock l(mu_);
  return reader_writer_->Write(event);
}

bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
  mutex_lock l(mu_);
  return reader_writer_->Read(event_reply);
}

void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
  EventReply event_reply;
  size_t num_replies = 0;
  while ((max_replies == 0 || ++num_replies <= max_replies) &&
         ReadEventReply(&event_reply)) {
    for (const EventReply::DebugOpStateChange& debug_op_state_change :
         event_reply.debug_op_state_changes()) {
      string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
                                         debug_op_state_change.output_slot(),
                                         ":", debug_op_state_change.debug_op());
      DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
                                            debug_op_state_change.state());
    }
  }
}

Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
  reader_writer_->WritesDone();
  // Read all EventReply messages (if any) from the server.
  ReceiveAndProcessEventReplies(0);

  if (reader_writer_->Finish().ok()) {
    return Status::OK();
  } else {
    return Status(error::FAILED_PRECONDITION,
                  "Failed to close debug GRPC stream.");
  }
}

mutex DebugGrpcIO::streams_mu(LINKER_INITIALIZED);

int64 DebugGrpcIO::channel_connection_timeout_micros = 900 * 1000 * 1000;
// TODO(cais): Make this configurable?

const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;

const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;

std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
DebugGrpcIO::GetStreamChannels() {
  static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
      stream_channels =
          new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
  return stream_channels;
}

Status DebugGrpcIO::SendTensorThroughGrpcStream(
    const DebugNodeKey& debug_node_key, const Tensor& tensor,
    const uint64 wall_time_us, const string& grpc_stream_url,
    const bool gated) {
  if (gated &&
      !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
    return Status::OK();
  } else {
    std::vector<Event> events;
    TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us,
                                          kGrpcMessageSizeLimitBytes, &events));
    for (const Event& event : events) {
      TF_RETURN_IF_ERROR(
          SendEventProtoThroughGrpcStream(event, grpc_stream_url));
    }
    if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
      DebugGrpcChannel* debug_grpc_channel = nullptr;
      TF_RETURN_IF_ERROR(
          GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
      debug_grpc_channel->ReceiveAndProcessEventReplies(1);
      // TODO(cais): Support new tensor value carried in the EventReply for
      // overriding the value of the tensor being published.
    }
    return Status::OK();
  }
}

Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
    EventReply* event_reply, const string& grpc_stream_url) {
  DebugGrpcChannel* debug_grpc_channel = nullptr;
  TF_RETURN_IF_ERROR(
      GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
  if (debug_grpc_channel->ReadEventReply(event_reply)) {
    return Status::OK();
  } else {
    return errors::Cancelled(strings::StrCat(
        "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
  }
}

Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
    const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
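  // For example, for grpc_stream_url "grpc://localhost:6064/some/path",
  // addr_with_path below is "localhost:6064/some/path" and server_stream_addr
  // is "localhost:6064".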
  const string addr_with_path =
      grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
          ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
          : grpc_stream_url;
  const string server_stream_addr =
      addr_with_path.substr(0, addr_with_path.find('/'));
  {
    mutex_lock l(streams_mu);
    std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
        stream_channels = GetStreamChannels();
    if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
      std::unique_ptr<DebugGrpcChannel> channel(
          new DebugGrpcChannel(server_stream_addr));
      TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros));
      stream_channels->insert(
          std::make_pair(grpc_stream_url, std::move(channel)));
    }
    *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
  }
  return Status::OK();
}

Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
    const Event& event_proto, const string& grpc_stream_url,
    const bool receive_reply) {
  DebugGrpcChannel* debug_grpc_channel;
  TF_RETURN_IF_ERROR(
      GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));

  bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
  if (!write_ok) {
    return errors::Cancelled(strings::StrCat("Write event to stream URL ",
                                             grpc_stream_url, " failed."));
  }

  if (receive_reply) {
    debug_grpc_channel->ReceiveAndProcessEventReplies(1);
  }

  return Status::OK();
}

bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
                                 const string& watch_key) {
  const DebugNodeName2State* enabled_node_to_state =
      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
}

bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
                                  const string& watch_key) {
  const DebugNodeName2State* enabled_node_to_state =
      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  auto it = enabled_node_to_state->find(watch_key);
  if (it == enabled_node_to_state->end()) {
    return false;
  } else {
    return it->second == EventReply::DebugOpStateChange::READ_WRITE;
  }
}

Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
  mutex_lock l(streams_mu);

  std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
      stream_channels = GetStreamChannels();
  if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
    // A stream to the specified address exists. Close it and remove it from
    // the record.
    Status s =
        (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
    (*stream_channels).erase(grpc_stream_url);
    return s;
  } else {
    // Stream of the specified address does not exist. No action.
    return Status::OK();
  }
}

std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
DebugGrpcIO::GetEnabledDebugOpStates() {
  static std::unordered_map<string, DebugNodeName2State>*
      enabled_debug_op_states =
          new std::unordered_map<string, DebugNodeName2State>();
  return enabled_debug_op_states;
}

DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
    const string& grpc_debug_url) {
  static mutex* debug_ops_state_mu = new mutex();
  std::unordered_map<string, DebugNodeName2State>* states =
      GetEnabledDebugOpStates();

  mutex_lock l(*debug_ops_state_mu);
  if (states->find(grpc_debug_url) == states->end()) {
    DebugNodeName2State url_enabled_debug_op_states;
    (*states)[grpc_debug_url] = url_enabled_debug_op_states;
  }
  return &(*states)[grpc_debug_url];
}

void DebugGrpcIO::SetDebugNodeKeyGrpcState(
    const string& grpc_debug_url, const string& watch_key,
    const EventReply::DebugOpStateChange::State new_state) {
  DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  if (new_state == EventReply::DebugOpStateChange::DISABLED) {
    if (states->find(watch_key) == states->end()) {
      LOG(ERROR) << "Attempt to disable a watch key that is not currently "
                 << "enabled at " << grpc_debug_url << ": " << watch_key;
    } else {
      states->erase(watch_key);
    }
  } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
    (*states)[watch_key] = new_state;
  }
}

void DebugGrpcIO::ClearEnabledWatchKeys() {
  GetEnabledDebugOpStates()->clear();
}

#endif  // #ifndef PLATFORM_WINDOWS

}  // namespace tensorflow