• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
17 #define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
18 
19 #include <cstddef>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "tensorflow/core/debug/debug_node_key.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/gtl/array_slice.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/util/event.pb.h"
34 
35 namespace tensorflow {
36 
37 Status ReadEventFromFile(const string& dump_file_path, Event* event);
38 
39 struct DebugWatchAndURLSpec {
DebugWatchAndURLSpecDebugWatchAndURLSpec40   DebugWatchAndURLSpec(const string& watch_key, const string& url,
41                        const bool gated_grpc)
42       : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {}
43 
44   const string watch_key;
45   const string url;
46   const bool gated_grpc;
47 };
48 
49 // TODO(cais): Put static functions and members in a namespace, not a class.
50 class DebugIO {
51  public:
52   static const char* const kDebuggerPluginName;
53 
54   static const char* const kCoreMetadataTag;
55   static const char* const kGraphTag;
56   static const char* const kHashTag;
57 
58   static const char* const kFileURLScheme;
59   static const char* const kGrpcURLScheme;
60   static const char* const kMemoryURLScheme;
61 
62   static Status PublishDebugMetadata(
63       const int64 global_step, const int64 session_run_index,
64       const int64 executor_step_index, const std::vector<string>& input_names,
65       const std::vector<string>& output_names,
66       const std::vector<string>& target_nodes,
67       const std::unordered_set<string>& debug_urls);
68 
69   // Publishes a tensor to a debug target URL.
70   //
71   // Args:
72   //   debug_node_key: A DebugNodeKey identifying the debug node.
73   //   tensor: The Tensor object being published.
74   //   wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
75   //   debug_urls: An array of debug target URLs, e.g.,
76   //     "file:///foo/tfdbg_dump", "grpc://localhost:11011"
77   //   gated_grpc: Whether this call is subject to gRPC gating.
78   static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
79                                    const Tensor& tensor,
80                                    const uint64 wall_time_us,
81                                    const gtl::ArraySlice<string>& debug_urls,
82                                    const bool gated_grpc);
83 
84   // Convenience overload of the method above for no gated_grpc by default.
85   static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
86                                    const Tensor& tensor,
87                                    const uint64 wall_time_us,
88                                    const gtl::ArraySlice<string>& debug_urls);
89 
90   // Publishes a graph to a set of debug URLs.
91   //
92   // Args:
93   //   graph: The graph to be published.
94   //   debug_urls: The set of debug URLs to publish the graph to.
95   static Status PublishGraph(const Graph& graph, const string& device_name,
96                              const std::unordered_set<string>& debug_urls);
97 
98   // Determines whether a copy node needs to perform deep-copy of input tensor.
99   //
100   // The input arguments contain sufficient information about the attached
101   // downstream debug ops for this method to determine whether all the said
102   // ops are disabled given the current status of the gRPC gating.
103   //
104   // Args:
105   //   specs: A vector of DebugWatchAndURLSpec carrying information about the
106   //     debug ops attached to the Copy node, their debug URLs and whether
107   //     they have the attribute value gated_grpc == True.
108   //
109   // Returns:
110   //   Whether any of the attached downstream debug ops is enabled given the
111   //   current status of the gRPC gating.
112   static bool IsCopyNodeGateOpen(
113       const std::vector<DebugWatchAndURLSpec>& specs);
114 
115   // Determines whether a debug node needs to proceed given the current gRPC
116   // gating status.
117   //
118   // Args:
119   //   watch_key: debug tensor watch key, in the format of
120   //     tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
121   //   debug_urls: the debug URLs of the debug node.
122   //
123   // Returns:
124   //   Whether this debug op should proceed.
125   static bool IsDebugNodeGateOpen(const string& watch_key,
126                                   const std::vector<string>& debug_urls);
127 
128   // Determines whether debug information should be sent through a grpc://
129   // debug URL given the current gRPC gating status.
130   //
131   // Args:
132   //   watch_key: debug tensor watch key, in the format of
133   //     tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
134   //   debug_url: the debug URL, e.g., "grpc://localhost:3333",
135   //     "file:///tmp/tfdbg_1".
136   //
137   // Returns:
138   //   Whether the sending of debug data to the debug_url should
139   //     proceed.
140   static bool IsDebugURLGateOpen(const string& watch_key,
141                                  const string& debug_url);
142 
143   static Status CloseDebugURL(const string& debug_url);
144 };
145 
146 // Helper class for debug ops.
147 class DebugFileIO {
148  public:
149   // Encapsulates the Tensor in an Event protobuf and write it to a directory.
150   // The actual path of the dump file will be a contactenation of
151   // dump_root_dir, tensor_name, along with the wall_time.
152   //
153   // For example:
154   //   let dump_root_dir = "/tmp/tfdbg_dump",
155   //       node_name = "foo/bar",
156   //       output_slot = 0,
157   //       debug_op = DebugIdentity,
158   //       and wall_time_us = 1467891234512345,
159   // the dump file will be generated at path:
160   //   /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345.
161   //
162   // Args:
163   //   debug_node_key: A DebugNodeKey identifying the debug node.
164   //   wall_time_us: Wall time at which the Tensor is generated during graph
165   //     execution. Unit: microseconds (us).
166   //   dump_root_dir: Root directory for dumping the tensor.
167   //   dump_file_path: The actual dump file path (passed as reference).
168   static Status DumpTensorToDir(const DebugNodeKey& debug_node_key,
169                                 const Tensor& tensor, const uint64 wall_time_us,
170                                 const string& dump_root_dir,
171                                 string* dump_file_path);
172 
173   // Get the full path to the dump file.
174   //
175   // Args:
176   //   dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump
177   //   node_name: Name of the node from which the dumped tensor is generated,
178   //     e.g., foo/bar/node_a
179   //   output_slot: Output slot index of the said node, e.g., 0.
180   //   debug_op: Name of the debug op, e.g., DebugIdentity.
181   //   wall_time_us: Time stamp of the dumped tensor, in microseconds (us).
182   static string GetDumpFilePath(const string& dump_root_dir,
183                                 const DebugNodeKey& debug_node_key,
184                                 const uint64 wall_time_us);
185 
186   // Dumps an Event proto to a file.
187   //
188   // Args:
189   //   event_prot: The Event proto to be dumped.
190   //   dir_name: Directory path.
191   //   file_name: Base file name.
192   static Status DumpEventProtoToFile(const Event& event_proto,
193                                      const string& dir_name,
194                                      const string& file_name);
195 
196   // Request additional bytes to be dumped to the file system.
197   //
198   // Does not actually dump the bytes, but instead just performs the
199   // bookkeeping necessary to prevent the total dumped amount of data from
200   // exceeding the limit (default 100 GBytes or set customly through the
201   // environment variable TFDBG_DISK_BYTES_LIMIT).
202   //
203   // Args:
204   //   bytes: Number of bytes to request.
205   //
206   // Returns:
207   //   Whether the request is approved given the total dumping
208   //   limit.
209   static bool requestDiskByteUsage(uint64 bytes);
210 
211   // Reset the disk byte usage to zero.
212   static void resetDiskByteUsage();
213 
214   static uint64 globalDiskBytesLimit;
215 
216  private:
217   // Encapsulates the Tensor in an Event protobuf and write it to file.
218   static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
219                                       const Tensor& tensor,
220                                       const uint64 wall_time_us,
221                                       const string& file_path);
222 
223   // Implemented ad hoc here for now.
224   // TODO(cais): Replace with shared implementation once http://b/30497715 is
225   // fixed.
226   static Status RecursiveCreateDir(Env* env, const string& dir);
227 
228   // Tracks how much disk has been used so far.
229   static uint64 diskBytesUsed;
230   // Mutex for thread-safe access to diskBytesUsed.
231   static mutex bytes_mu;
232   // Default limit for the disk space.
233   static const uint64 defaultGlobalDiskBytesLimit;
234 
235   friend class DiskUsageLimitTest;
236 };
237 
238 }  // namespace tensorflow
239 
240 namespace std {
241 
242 template <>
243 struct hash<::tensorflow::DebugNodeKey> {
244   size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
245     return ::tensorflow::Hash64(
246         ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
247                                       k.output_slot, ":", k.debug_op, ":"));
248   }
249 };
250 
251 }  // namespace std
252 
253 // TODO(cais): Support grpc:// debug URLs in open source once Python grpc
254 //   genrule becomes available. See b/23796275.
255 #ifndef PLATFORM_WINDOWS
256 #include "grpcpp/channel.h"
257 #include "tensorflow/core/debug/debug_service.grpc.pb.h"
258 
259 namespace tensorflow {
260 
261 class DebugGrpcChannel {
262  public:
263   // Constructor of DebugGrpcChannel.
264   //
265   // Args:
266   //   server_stream_addr: Address (host name and port) of the debug stream
267   //     server implementing the EventListener service (see
268   //     debug_service.proto). E.g., "127.0.0.1:12345".
269   DebugGrpcChannel(const string& server_stream_addr);
270 
271   virtual ~DebugGrpcChannel() {}
272 
273   // Attempt to establish connection with server.
274   //
275   // Args:
276   //   timeout_micros: Timeout (in microseconds) for the attempt to establish
277   //     the connection.
278   //
279   // Returns:
280   //   OK Status iff connection is successfully established before timeout,
281   //   otherwise return an error Status.
282   Status Connect(const int64 timeout_micros);
283 
284   // Write an Event proto to the debug gRPC stream.
285   //
286   // Thread-safety: Safe with respect to other calls to the same method and
287   //   calls to ReadEventReply() and Close().
288   //
289   // Args:
290   //   event: The event proto to be written to the stream.
291   //
292   // Returns:
293   //   True iff the write is successful.
294   bool WriteEvent(const Event& event);
295 
296   // Read an EventReply proto from the debug gRPC stream.
297   //
298   // This method blocks and waits for an EventReply from the server.
299   // Thread-safety: Safe with respect to other calls to the same method and
300   //   calls to WriteEvent() and Close().
301   //
302   // Args:
303   //   event_reply: the to-be-modified EventReply proto passed as reference.
304   //
305   // Returns:
306   //   True iff the read is successful.
307   bool ReadEventReply(EventReply* event_reply);
308 
309   // Receive and process EventReply protos from the gRPC debug server.
310   //
311   // The processing includes setting debug watch key states using the
312   // DebugOpStateChange fields of the EventReply.
313   //
314   // Args:
315   //   max_replies: Maximum number of replies to receive. Will receive all
316   //     remaining replies iff max_replies == 0.
317   void ReceiveAndProcessEventReplies(size_t max_replies);
318 
319   // Receive EventReplies from server (if any) and close the stream and the
320   // channel.
321   Status ReceiveServerRepliesAndClose();
322 
323  private:
324   string server_stream_addr_;
325   string url_;
326   ::grpc::ClientContext ctx_;
327   std::shared_ptr<::grpc::Channel> channel_;
328   std::unique_ptr<EventListener::Stub> stub_;
329   std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>>
330       reader_writer_;
331 
332   mutex mu_;
333 };
334 
335 class DebugGrpcIO {
336  public:
337   static const size_t kGrpcMessageSizeLimitBytes;
338   static const size_t kGrpcMaxVarintLengthSize;
339 
340   // Sends a tensor through a debug gRPC stream.
341   static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key,
342                                             const Tensor& tensor,
343                                             const uint64 wall_time_us,
344                                             const string& grpc_stream_url,
345                                             const bool gated);
346 
347   // Sends an Event proto through a debug gRPC stream.
348   // Thread-safety: Safe with respect to other calls to the same method and
349   // calls to CloseGrpcStream().
350   //
351   // Args:
352   //   event_proto: The Event proto to be sent.
353   //   grpc_stream_url: The grpc:// URL of the stream to use, e.g.,
354   //     "grpc://localhost:11011", "localhost:22022".
355   //   receive_reply: Whether an EventReply proto will be read after event_proto
356   //     is sent and before the function returns.
357   //
358   // Returns:
359   //   The Status of the operation.
360   static Status SendEventProtoThroughGrpcStream(
361       const Event& event_proto, const string& grpc_stream_url,
362       const bool receive_reply = false);
363 
364   // Receive an EventReply proto through a debug gRPC stream.
365   static Status ReceiveEventReplyProtoThroughGrpcStream(
366       EventReply* event_reply, const string& grpc_stream_url);
367 
368   // Check whether a debug watch key is read-activated at a given gRPC URL.
369   static bool IsReadGateOpen(const string& grpc_debug_url,
370                              const string& watch_key);
371 
372   // Check whether a debug watch key is write-activated (i.e., read- and
373   // write-activated) at a given gRPC URL.
374   static bool IsWriteGateOpen(const string& grpc_debug_url,
375                               const string& watch_key);
376 
377   // Closes a gRPC stream to the given address, if it exists.
378   // Thread-safety: Safe with respect to other calls to the same method and
379   // calls to SendTensorThroughGrpcStream().
380   static Status CloseGrpcStream(const string& grpc_stream_url);
381 
382   // Set the gRPC state of a debug node key.
383   // TODO(cais): Include device information in watch_key.
384   static void SetDebugNodeKeyGrpcState(
385       const string& grpc_debug_url, const string& watch_key,
386       const EventReply::DebugOpStateChange::State new_state);
387 
388  private:
389   using DebugNodeName2State =
390       std::unordered_map<string, EventReply::DebugOpStateChange::State>;
391 
392   // Returns a global map from grpc debug URLs to the corresponding
393   // DebugGrpcChannels.
394   static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
395   GetStreamChannels();
396 
397   // Get a DebugGrpcChannel object at a given URL, creating one if necessary.
398   //
399   // Args:
400   //   grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064"
401   //   debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a
402   //     a pointer to the pointer. The DebugGrpcChannel object is owned
403   //     statically elsewhere, not by the caller of this function.
404   //
405   // Returns:
406   //   Status of this operation.
407   static Status GetOrCreateDebugGrpcChannel(
408       const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel);
409 
410   // Returns a map from debug URL to a map from debug op name to enabled state.
411   static std::unordered_map<string, DebugNodeName2State>*
412   GetEnabledDebugOpStates();
413 
414   // Returns a map from debug op names to enabled state, for a given debug URL.
415   static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
416       const string& grpc_debug_url);
417 
418   // Clear enabled debug op state from all debug URLs (if any).
419   static void ClearEnabledWatchKeys();
420 
421   static mutex streams_mu;
422   static int64 channel_connection_timeout_micros;
423 
424   friend class GrpcDebugTest;
425   friend class DebugNumericSummaryOpTest;
426 };
427 
428 }  // namespace tensorflow
429 #endif  // #ifndef(PLATFORM_WINDOWS)
430 
431 #endif  // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
432