1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ 17 #define MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ 18 19 #include <list> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 #include <map> 25 #include "backend/session/kernel_graph.h" 26 #include "debug/debugger/grpc_client.h" 27 #include "debug/debug_services.h" 28 #include "common/trans.h" 29 30 using debugger::Chunk; 31 using debugger::DataType; 32 using debugger::EventReply; 33 using debugger::GraphProto; 34 using debugger::ModelProto; 35 using debugger::Statistics; 36 using debugger::TensorProto; 37 using debugger::WatchCondition; 38 using debugger::WatchCondition_Parameter; 39 using debugger::WatchNode; 40 using debugger::WatchpointHit; 41 42 template <class T> 43 using ProtoVector = google::protobuf::RepeatedPtrField<T>; 44 45 namespace mindspore { 46 // different types of command received by debugger 47 // need to keep sync with client-side proto and server-side proto 48 enum class DebuggerCommand { 49 kExitCMD = 2, 50 kRunCMD = 3, 51 kSetCMD = 4, 52 kViewCMD = 5, 53 kVersionMatchedCMD = 6, 54 kUnknownCMD = -1 55 }; 56 57 class Debugger : public std::enable_shared_from_this<Debugger> { 58 public: GetInstance()59 static std::shared_ptr<Debugger> GetInstance() { 60 std::lock_guard<std::mutex> i_lock(instance_lock_); 61 if (debugger_ == nullptr) { 62 debugger_ = std::shared_ptr<Debugger>(new (std::nothrow) Debugger()); 63 } 64 return debugger_; 65 } 66 67 // deconstructor 68 ~Debugger() = default; 69 70 // init 71 // only save device_id 72 void Init(const uint32_t device_id, const std::string device_target); 73 74 // reset debugger 75 void Reset(); 76 77 void PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs); 78 // enable debugger 79 // send graph and wait for command 80 // do nothing if graph is set already 81 void PreExecute(const KernelGraphPtr &graph_ptr); 82 83 // analyze tensors and wait for command 84 // don't need a graph_ptr because it is saved during pre_execute 85 void PostExecute(); 86 87 bool DumpDataEnabledIteration() const; 88 89 static uint32_t GetRankID(); 90 91 void Dump(const KernelGraphPtr &kernel_graph) const; 92 93 void DumpSingleNode(const CNodePtr &node, uint32_t graph_id); 94 95 void DumpSetup(const KernelGraphPtr &kernel_graph) const; 96 97 void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph); 98 99 void PostExecuteGraphDebugger(); 100 101 bool ReadNodeDataRequired(const CNodePtr &kernel) const; 102 103 void PostExecuteNode(const CNodePtr &kernel, bool last_kernel); 104 105 bool DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath, 106 const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type, 107 TypeId device_type, const std::string &addr_format, size_t slot) const; 108 109 bool LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev); 110 111 bool debugger_enabled() const; 112 113 bool partial_memory() const; 114 115 void SetEnableHeartbeat(bool enabled); 116 117 void SetCurNode(const std::string &cur_name); 118 119 std::string run_level() const; 120 121 // check if any feature that uses the debugger backend is enabled 122 bool DebuggerBackendEnabled() const; 123 124 void SetTrainingDone(bool training_done); 125 126 // returns true if reply received and mindspore version matched with mindinsight version 127 // version_check should be true if you want the function to do backend compatibility check with Mindinsight 128 bool SendMetadata(bool version_check); 129 130 void LoadParametersAndConst(); 131 132 void LoadParametersAndConst(const KernelGraphPtr &graph); 133 134 void UpdateStepNum(const session::KernelGraph *graph); 135 136 void UpdateStepNumGPU(); 137 138 void ClearCurrentData(); 139 140 void LoadGraphOutputs(); 141 142 void CheckDatasetSinkMode(); 143 144 void LoadGraphs(const KernelGraphPtr &graph_ptr); 145 146 uint32_t GetFirstRunGraphId() const; 147 SetGraphPtr(const KernelGraphPtr & graph_ptr)148 void SetGraphPtr(const KernelGraphPtr &graph_ptr) { graph_ptr_ = graph_ptr; } 149 GetGraphPtr()150 const KernelGraphPtr GetGraphPtr() const { return graph_ptr_; } 151 GetGraphPtrList()152 const std::list<KernelGraphPtr> GetGraphPtrList() const { return graph_ptr_list_; } 153 154 bool TensorExistsInCurrent(const std::string &tensor_name); 155 156 // check if dump using debugger backend is enabled 157 bool CheckDebuggerDumpEnabled() const; 158 159 private: 160 // private constructor for singleton 161 Debugger(); 162 163 // enable debugger 164 // instantiate class members 165 // read env variable for grpc client 166 void EnableDebugger(); 167 168 // check if debugger enabled 169 bool CheckDebuggerEnabled() const; 170 171 void CheckDebuggerEnabledParam() const; 172 173 bool CheckDebuggerPartialMemoryEnabled() const; 174 175 // check and save graph pointer 176 void CheckGraphPtr(const KernelGraphPtr &graph_ptr); 177 178 // check if the graph is a dataset graph 179 void CheckDatasetGraph(); 180 181 // serialize graph and get proto 182 GraphProto GetGraphProto(const KernelGraphPtr &graph_ptr) const; 183 184 // send heartbeat message to UI once per 30 second by default 185 void SendHeartbeat(int32_t period); 186 187 // send graph and enter command wait loop 188 void SendGraphAndSuspend(const GraphProto &graph_proto); 189 190 void SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list); 191 192 // send multi_graphs and clear the graph_proto_list_ 193 void SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr); 194 195 // wait for command and process command 196 // send command request and process reply in a loop 197 // break if RunCMD 198 void CommandLoop(); 199 200 // Process the RunCMD 201 void ProcessRunCMD(const EventReply &reply); 202 // Process the KSetCMD 203 void ProcessKSetCMD(const EventReply &reply); 204 // Process the KViewCMD 205 void ProcessKViewCMD(const EventReply &reply); 206 // ViewCMD base level 207 void ViewBaseLevel(const EventReply &reply); 208 // ViewCMD statistics level 209 void ViewStatLevel(const EventReply &reply); 210 // ViewCMD value level 211 void ViewValueLevel(const EventReply &reply); 212 // set what nodes and conditions to watch 213 void SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id, 214 const ProtoVector<WatchCondition_Parameter> ¶meters); 215 216 // remove watchpoint with id 217 void RemoveWatchpoint(const int32_t id); 218 219 // load tensor for view command 220 std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors) const; 221 222 // load tensor base for view command 223 std::list<TensorBase> LoadTensorsBase(const ProtoVector<TensorProto> &tensors) const; 224 225 // load tensor statistics for view command 226 std::list<TensorSummary> LoadTensorsStat(const ProtoVector<TensorProto> &tensors) const; 227 228 // terminate training process 229 void Exit(bool exit_success = false); 230 231 // analyze tensors and check watchpoint conditions 232 // return names of tensors and what condition they hit 233 std::list<WatchpointHit> CheckWatchpoints(const std::string &watchnode = std::string(), 234 const CNodePtr &kernel = nullptr, bool recheck = false); 235 236 // send watchpoints that hit 237 void SendWatchpoints(const std::list<WatchpointHit> &points); 238 239 // Check if the port is valid 240 bool CheckPort(const std::string &port) const; 241 242 // Check if the IP is valid 243 bool CheckIp(const std::string &host) const; 244 245 void LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index); 246 247 // class members 248 249 std::unique_ptr<GrpcClient> grpc_client_; 250 std::unique_ptr<DebugServices> debug_services_; 251 std::unique_ptr<std::thread> heartbeat_thread_; 252 KernelGraphPtr graph_ptr_; 253 uint32_t device_id_; 254 std::string device_target_; 255 int32_t num_step_; 256 bool debugger_enabled_; 257 bool suspended_at_last_kernel_; 258 std::string run_level_; 259 std::string node_name_; 260 std::string cur_name_; 261 bool training_done_; 262 bool is_dataset_graph_; 263 bool partial_memory_; 264 std::mutex access_lock_; 265 // flag to keep track of the very first suspension of debugger 266 bool initial_suspend_; 267 bool enable_heartbeat_; 268 269 std::list<GraphProto> graph_proto_list_; 270 std::list<KernelGraphPtr> graph_ptr_list_; 271 // The vector of graph pointers that have been run in the current step. 272 std::vector<KernelGraphPtr> graph_ptr_step_vec_; 273 274 // singleton 275 static std::mutex instance_lock_; 276 static std::shared_ptr<Debugger> debugger_; 277 uint32_t not_dataset_graph_sum_; 278 std::list<uint32_t> rungraph_id_list_; 279 std::string version_; 280 }; 281 282 using DebuggerPtr = std::shared_ptr<Debugger>; 283 // get debugger ModelProto 284 ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph); 285 286 // for getting proto DataType from Type of Tensor 287 DataType GetDebuggerNumberDataType(const TypePtr &type); 288 289 // process reply and command type 290 DebuggerCommand GetCommand(const EventReply &reply); 291 292 // parse other data out of EventReply 293 ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply); 294 ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply); 295 std::string GetNodeName(const EventReply &reply); 296 std::string GetRunLevel(const EventReply &reply); 297 WatchCondition GetWatchcondition(const EventReply &reply); 298 int32_t GetWatchpointID(const EventReply &reply); 299 bool GetWatchpointDelete(const EventReply &reply); 300 ProtoVector<TensorProto> GetTensors(const EventReply &reply); 301 bool GetMiVersionMatched(const EventReply &reply); 302 // get the full name of a tensor, which is the name used in TensorLoader 303 std::string GetTensorFullName(const TensorProto &tensor); 304 } // namespace mindspore 305 #endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ 306