• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
17 #define MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
18 
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "google/protobuf/repeated_field.h"
#include "include/backend/debug/tensor_data.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/device_address.h"
#include "include/backend/visible.h"
31 
// Forward declarations of the debugger message types. Their full definitions
// are not included by this header (presumably they come from the generated
// debugger protobuf headers — confirm in the build).
namespace debugger {
// Types that appear in the Debugger class declarations in this header.
class EventReply;
class GraphProto;
class TensorBase;
class TensorProto;
class TensorSummary;
class WatchCondition;
class WatchCondition_Parameter;
class WatchNode;
class WatchpointHit;

// Types not referenced by the declarations in this header; presumably kept
// for translation units that include it — verify before removing.
class Chunk;
class ModelProto;
class Statistics;

// Opaque enum declaration: the fixed underlying type (int) allows declaring it
// without listing its enumerators.
enum DataType : int;
}  // namespace debugger
47 
48 template <class T>
49 using ProtoVector = google::protobuf::RepeatedPtrField<T>;
50 
51 namespace mindspore {
// Collaborators that Debugger holds only through std::unique_ptr members, so
// incomplete declarations are sufficient here.
class DebugServices;
class GrpcClient;

// Kinds of command the debugger can receive from its client.
// The numeric values must stay in sync with the client-side and server-side
// proto definitions, so every enumerator is pinned to an explicit value.
enum class DebuggerCommand : int {
  kExitCMD = 2,            // presumably asks the backend to exit
  kRunCMD = 3,             // run: ends the command-wait loop
  kSetCMD = 4,             // set command (presumably watchpoint configuration)
  kViewCMD = 5,            // view command (tensor inspection)
  kVersionMatchedCMD = 6,  // client reports that versions matched
  kUnknownCMD = -1         // anything not recognized
};
64 
65 class BACKEND_EXPORT Debugger : public std::enable_shared_from_this<Debugger> {
66  public:
67   static std::shared_ptr<Debugger> GetInstance();
68 
69   // deconstructor
70   ~Debugger() = default;
71 
72   // init
73   // only save device_id
74   void Init(const uint32_t device_id, const std::string device_target);
75 
76   // reset debugger
77   void Reset();
78 
79   void PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs,
80                                const std::vector<AnfNodePtr> &origin_parameters_order);
81   // enable debugger
82   // send graph and wait for command
83   // do nothing if graph is set already
84   void PreExecute(const KernelGraphPtr &graph_ptr);
85 
86   void SetCurrentAndPrevRootGraph(uint32_t root_graph_id);
87 
SetAscendKernelByKernelFlag(bool value)88   void SetAscendKernelByKernelFlag(bool value) { ascend_kernel_by_kernel_ = value; }
89 
GetAscendKernelByKernelFlag()90   bool GetAscendKernelByKernelFlag() const { return ascend_kernel_by_kernel_; }
91 
92   void StoreRunGraphIdList(uint32_t graph_id);
93 
94   // analyze tensors and wait for command
95   // don't need a graph_ptr because it is saved during pre_execute
96   void PostExecute();
97 
98   static uint32_t GetRankID();
99 
100   void DumpConstantDataAscend(const KernelGraphPtr &graph);
101 
102   void DumpSingleNode(const CNodePtr &node, uint32_t graph_id) const;
103 
104   void DumpInGraphCompiler(const KernelGraphPtr &kernel_graph);
105 
106   void PostExecuteGraphDebugger();
107 
108   bool ReadNodeDataRequired(const CNodePtr &kernel) const;
109 
110   void PostExecuteNode(const CNodePtr &kernel, bool last_kernel);
111 
112   bool DumpTensorToFile(const std::string &filepath, const std::string &tensor_name, size_t slot) const;
113 
114   bool LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev);
115 
116   std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) const;
117 
118   bool debugger_enabled() const;
119 
120   bool partial_memory() const;
121 
122   void SetEnableHeartbeat(bool enabled);
123 
124   void SetCurNode(const std::string &cur_name);
125 
126   std::string run_level() const;
127 
128   // check if any feature that uses the debugger backend is enabled
129   bool DebuggerBackendEnabled() const;
130 
131   void SetTrainingDone(bool training_done);
132 
133   // returns true if reply received and mindspore version matched with mindInsight version
134   // version_check should be true if you want the function to do backend compatibility check with MindInsight
135   bool SendMetadata(bool version_check);
136 
137   bool CheckSendMetadata();
138 
139   void LoadParametersAndConst();
140 
141   void LoadParametersAndConst(const KernelGraphPtr &graph);
142 
143   void LoadParametersAllGraphs();
144 
145   void LoadConstsForGraph(const KernelGraphPtr &graph);
146 
147   void DumpParamsAndConstAndHistory();
148 
149   void UpdateStepNum(const session::KernelGraph *graph);
150 
151   void UpdateStepNumGPU();
152 
153   void ClearCurrentData();
154 
155   void LoadGraphOutputs();
156 
157   void CheckDatasetSinkMode(const KernelGraphPtr &graph_ptr);
158 
159   void LoadGraphs(const KernelGraphPtr &graph_ptr);
160 
161   uint32_t GetFirstRunGraphId() const;
162 
GetCurrentRootGraphId()163   uint32_t GetCurrentRootGraphId() const { return cur_root_graph_id_; }
164 
GetPrevRootGraphId()165   uint32_t GetPrevRootGraphId() const { return prev_root_graph_id_; }
166 
GetStepGraphPtrList()167   std::vector<KernelGraphPtr> GetStepGraphPtrList() const { return graph_ptr_step_vec_; }
168 
InsertExecutedGraph(const KernelGraphPtr & graph_ptr)169   void InsertExecutedGraph(const KernelGraphPtr &graph_ptr) { (void)executed_graph_ptr_set_.insert(graph_ptr); }
170 
SetGraphPtr(const KernelGraphPtr & graph_ptr)171   void SetGraphPtr(const KernelGraphPtr &graph_ptr) { graph_ptr_ = graph_ptr; }
172 
GetGraphPtr()173   const KernelGraphPtr GetGraphPtr() const { return graph_ptr_; }
174 
GetGraphPtrList()175   const std::list<KernelGraphPtr> GetGraphPtrList() const { return graph_ptr_list_; }
176 
177   bool TensorExistsInCurrent(const std::string &tensor_name);
178 
179   // check if dump using debugger backend is enabled
180   bool CheckDebuggerDumpEnabled() const;
181 
182   // check if debugger is enabled
183   bool CheckDebuggerEnabled() const;
184 
GetGraphIterMap()185   std::map<uint32_t, int32_t> GetGraphIterMap() { return graph_iter_num_map_; }
186 
187   void UpdateGraphIterMap(uint32_t graph_id, int32_t iter_num);
188 
GetParametersMindRT()189   std::vector<AnfNodePtr> GetParametersMindRT() const { return parameters_mindRT_; }
190 
191  private:
192   // private constructor for singleton
193   Debugger();
194 
195   // enable debugger
196   // instantiate class members
197   // read env variable for grpc client
198   void EnableDebugger();
199 
200   void CheckDebuggerEnabledParam() const;
201 
202   bool CheckDebuggerPartialMemoryEnabled() const;
203 
204   // check and save graph pointer
205   void CheckGraphPtr(const KernelGraphPtr &graph_ptr);
206 
207   // check if the graph is a dataset graph
208   void CheckDatasetGraph();
209 
210   // serialize graph and get proto
211   debugger::GraphProto GetGraphProto(const KernelGraphPtr &graph_ptr) const;
212 
213   // send heartbeat message to UI once per 30 second by default
214   void SendHeartbeat(int32_t period);
215 
216   // send graph and enter command wait loop
217   void SendGraphAndSuspend(const debugger::GraphProto &graph_proto);
218 
219   void SendMultiGraphsAndSuspend(const std::list<debugger::GraphProto> &graph_proto_list);
220 
221   // send multi_graphs and clear the graph_proto_list_
222   void SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr);
223 
224   // wait for command and process command
225   // send command request and process reply in a loop
226   // break if RunCMD
227   void CommandLoop();
228 
229   // Process the RunCMD
230   void ProcessRunCMD(const debugger::EventReply &reply);
231   // Process the KSetCMD
232   void ProcessKSetCMD(const debugger::EventReply &reply);
233   // Process the KViewCMD
234   void ProcessKViewCMD(const debugger::EventReply &reply);
235   // ViewCMD base level
236   void ViewBaseLevel(const debugger::EventReply &reply);
237   // ViewCMD statistics level
238   void ViewStatLevel(const debugger::EventReply &reply);
239   // ViewCMD value level
240   void ViewValueLevel(const debugger::EventReply &reply);
241   // set what nodes and conditions to watch
242   void SetWatchpoint(const ProtoVector<debugger::WatchNode> &nodes, const debugger::WatchCondition &condition,
243                      const int32_t id, const ProtoVector<debugger::WatchCondition_Parameter> &parameters);
244 
245   // remove watchpoint with id
246   void RemoveWatchpoint(const int32_t id);
247 
248   // load tensor for view command
249   std::list<debugger::TensorProto> LoadTensors(const ProtoVector<debugger::TensorProto> &tensors) const;
250 
251   // load tensor base for view command
252   std::list<debugger::TensorBase> LoadTensorsBase(const ProtoVector<debugger::TensorProto> &tensors) const;
253 
254   // load tensor statistics for view command
255   std::list<debugger::TensorSummary> LoadTensorsStat(const ProtoVector<debugger::TensorProto> &tensors) const;
256 
257   // terminate training process
258   void Exit(bool exit_success = false);
259 
260   // analyze tensors and check watchpoint conditions
261   // return names of tensors and what condition they hit
262   std::list<debugger::WatchpointHit> CheckWatchpoints(const std::string &watchnode = std::string(),
263                                                       const CNodePtr &kernel = nullptr, bool recheck = false);
264 
265   // send watchpoints that hit
266   void SendWatchpoints(const std::list<debugger::WatchpointHit> &points);
267 
268   // Check if the port is valid
269   bool CheckPort(const std::string &port) const;
270 
271   // Check if the IP is valid
272   bool CheckIp(const std::string &host) const;
273 
274   void LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index, uint32_t root_graph_id);
275 
276   void LoadSingleParameterMindRT(const AnfNodePtr &anf_node);
277 
278   // class members
279 
280   std::unique_ptr<GrpcClient> grpc_client_;
281   std::unique_ptr<DebugServices> debug_services_;
282   std::unique_ptr<std::thread> heartbeat_thread_;
283   KernelGraphPtr graph_ptr_;
284   uint32_t device_id_;
285   std::string device_target_;
286   int32_t num_step_;
287   bool debugger_enabled_;
288   bool suspended_at_last_kernel_;
289   std::string run_level_;
290   std::string node_name_;
291   std::string cur_name_;
292   bool training_done_;
293   bool send_metadata_done_;
294   bool received_new_graph_;
295   bool is_dataset_graph_;
296   bool partial_memory_;
297   std::mutex access_lock_;
298   uint32_t cur_root_graph_id_ = UINT32_MAX;
299   uint32_t prev_root_graph_id_ = UINT32_MAX;
300   // flag to keep track of the very first suspension of debugger
301   bool initial_suspend_;
302   bool enable_heartbeat_;
303 
304   std::list<debugger::GraphProto> graph_proto_list_;
305   std::list<KernelGraphPtr> graph_ptr_list_;
306   // The vector of all the kernel graph pointers for the root graph that will execute in the current step.
307   std::vector<KernelGraphPtr> graph_ptr_step_vec_;
308   // The set of graph pointers that have been run in the current step.
309   std::set<KernelGraphPtr> executed_graph_ptr_set_;
310   // The vector of all the parameters for the current step for mindRT.
311   std::vector<AnfNodePtr> parameters_mindRT_;
312   std::vector<uint32_t> visited_root_graph_ids_;
313 
314   // map to store iter num in each epoch when dataset_sink_mode is true
315   std::map<uint32_t, int32_t> graph_iter_num_map_;
316 
317   // singleton
318   inline static std::mutex instance_lock_ = {};
319   inline static std::shared_ptr<Debugger> debugger_ = nullptr;
320   uint32_t not_dataset_graph_sum_;
321   std::list<uint32_t> rungraph_id_list_;
322   bool ascend_kernel_by_kernel_;
323   bool enable_debugger_called_;
324   std::string version_;
325 };
326 using DebuggerPtr = std::shared_ptr<Debugger>;
327 }  // namespace mindspore
328 #endif  // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
329