1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ 18 19 #include <memory> 20 21 #include "tensorflow/core/common_runtime/device.h" 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/lib/core/status.h" 24 #include "tensorflow/core/platform/macros.h" 25 #include "tensorflow/core/platform/protobuf.h" 26 #include "tensorflow/core/protobuf/debug.pb.h" 27 28 namespace tensorflow { 29 30 // Returns a summary string for the list of debug tensor watches. 31 const string SummarizeDebugTensorWatches( 32 const protobuf::RepeatedPtrField<DebugTensorWatch>& watches); 33 34 // An abstract interface for storing and retrieving debugging information. 35 class DebuggerStateInterface { 36 public: ~DebuggerStateInterface()37 virtual ~DebuggerStateInterface() {} 38 39 // Publish metadata about the debugged Session::Run() call. 40 // 41 // Args: 42 // global_step: A global step count supplied by the caller of 43 // Session::Run(). 44 // session_run_index: A chronologically sorted index for calls to the Run() 45 // method of the Session object. 46 // executor_step_index: A chronologically sorted index of invocations of the 47 // executor charged to serve this Session::Run() call. 48 // input_names: Name of the input Tensors (feed keys). 49 // output_names: Names of the fetched Tensors. 50 // target_names: Names of the target nodes. 51 virtual Status PublishDebugMetadata( 52 const int64 global_step, const int64 session_run_index, 53 const int64 executor_step_index, const std::vector<string>& input_names, 54 const std::vector<string>& output_names, 55 const std::vector<string>& target_nodes) = 0; 56 }; 57 58 class DebugGraphDecoratorInterface { 59 public: ~DebugGraphDecoratorInterface()60 virtual ~DebugGraphDecoratorInterface() {} 61 62 // Insert special-purpose debug nodes to graph and dump the graph for 63 // record. See the documentation of DebugNodeInserter::InsertNodes() for 64 // details. 65 virtual Status DecorateGraph(Graph* graph, Device* device) = 0; 66 67 // Publish Graph to debug URLs. 68 virtual Status PublishGraph(const Graph& graph, 69 const string& device_name) = 0; 70 }; 71 72 typedef std::function<std::unique_ptr<DebuggerStateInterface>( 73 const DebugOptions& options)> 74 DebuggerStateFactory; 75 76 // Contains only static methods for registering DebuggerStateFactory. 77 // We don't expect to create any instances of this class. 78 // Call DebuggerStateRegistry::RegisterFactory() at initialization time to 79 // define a global factory that creates instances of DebuggerState, then call 80 // DebuggerStateRegistry::CreateState() to create a single instance. 81 class DebuggerStateRegistry { 82 public: 83 // Registers a function that creates a concrete DebuggerStateInterface 84 // implementation based on DebugOptions. 85 static void RegisterFactory(const DebuggerStateFactory& factory); 86 87 // If RegisterFactory() has been called, creates and supplies a concrete 88 // DebuggerStateInterface implementation using the registered factory, 89 // owned by the caller and return an OK Status. Otherwise returns an error 90 // Status. 91 static Status CreateState(const DebugOptions& debug_options, 92 std::unique_ptr<DebuggerStateInterface>* state); 93 94 private: 95 static DebuggerStateFactory* factory_; 96 97 TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry); 98 }; 99 100 typedef std::function<std::unique_ptr<DebugGraphDecoratorInterface>( 101 const DebugOptions& options)> 102 DebugGraphDecoratorFactory; 103 104 class DebugGraphDecoratorRegistry { 105 public: 106 static void RegisterFactory(const DebugGraphDecoratorFactory& factory); 107 108 static Status CreateDecorator( 109 const DebugOptions& options, 110 std::unique_ptr<DebugGraphDecoratorInterface>* decorator); 111 112 private: 113 static DebugGraphDecoratorFactory* factory_; 114 115 TF_DISALLOW_COPY_AND_ASSIGN(DebugGraphDecoratorRegistry); 116 }; 117 118 } // end namespace tensorflow 119 120 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ 121