• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
18 
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/debug.pb.h"
27 
28 namespace tensorflow {
29 
30 // Returns a summary string for the list of debug tensor watches.
31 const string SummarizeDebugTensorWatches(
32     const protobuf::RepeatedPtrField<DebugTensorWatch>& watches);
33 
34 // An abstract interface for storing and retrieving debugging information.
35 class DebuggerStateInterface {
36  public:
~DebuggerStateInterface()37   virtual ~DebuggerStateInterface() {}
38 
39   // Publish metadata about the debugged Session::Run() call.
40   //
41   // Args:
42   //   global_step: A global step count supplied by the caller of
43   //     Session::Run().
44   //   session_run_index: A chronologically sorted index for calls to the Run()
45   //     method of the Session object.
46   //   executor_step_index: A chronologically sorted index of invocations of the
47   //     executor charged to serve this Session::Run() call.
48   //   input_names: Name of the input Tensors (feed keys).
49   //   output_names: Names of the fetched Tensors.
50   //   target_names: Names of the target nodes.
51   virtual Status PublishDebugMetadata(
52       const int64 global_step, const int64 session_run_index,
53       const int64 executor_step_index, const std::vector<string>& input_names,
54       const std::vector<string>& output_names,
55       const std::vector<string>& target_nodes) = 0;
56 };
57 
58 class DebugGraphDecoratorInterface {
59  public:
~DebugGraphDecoratorInterface()60   virtual ~DebugGraphDecoratorInterface() {}
61 
62   // Insert special-purpose debug nodes to graph and dump the graph for
63   // record. See the documentation of DebugNodeInserter::InsertNodes() for
64   // details.
65   virtual Status DecorateGraph(Graph* graph, Device* device) = 0;
66 
67   // Publish Graph to debug URLs.
68   virtual Status PublishGraph(const Graph& graph,
69                               const string& device_name) = 0;
70 };
71 
72 typedef std::function<std::unique_ptr<DebuggerStateInterface>(
73     const DebugOptions& options)>
74     DebuggerStateFactory;
75 
76 // Contains only static methods for registering DebuggerStateFactory.
77 // We don't expect to create any instances of this class.
78 // Call DebuggerStateRegistry::RegisterFactory() at initialization time to
79 // define a global factory that creates instances of DebuggerState, then call
80 // DebuggerStateRegistry::CreateState() to create a single instance.
81 class DebuggerStateRegistry {
82  public:
83   // Registers a function that creates a concrete DebuggerStateInterface
84   // implementation based on DebugOptions.
85   static void RegisterFactory(const DebuggerStateFactory& factory);
86 
87   // If RegisterFactory() has been called, creates and supplies a concrete
88   // DebuggerStateInterface implementation using the registered factory,
89   // owned by the caller and return an OK Status. Otherwise returns an error
90   // Status.
91   static Status CreateState(const DebugOptions& debug_options,
92                             std::unique_ptr<DebuggerStateInterface>* state);
93 
94  private:
95   static DebuggerStateFactory* factory_;
96 
97   TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry);
98 };
99 
100 typedef std::function<std::unique_ptr<DebugGraphDecoratorInterface>(
101     const DebugOptions& options)>
102     DebugGraphDecoratorFactory;
103 
104 class DebugGraphDecoratorRegistry {
105  public:
106   static void RegisterFactory(const DebugGraphDecoratorFactory& factory);
107 
108   static Status CreateDecorator(
109       const DebugOptions& options,
110       std::unique_ptr<DebugGraphDecoratorInterface>* decorator);
111 
112  private:
113   static DebugGraphDecoratorFactory* factory_;
114 
115   TF_DISALLOW_COPY_AND_ASSIGN(DebugGraphDecoratorRegistry);
116 };
117 
118 }  // end namespace tensorflow
119 
120 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
121