/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dirent.h>
#include <cstdio>
#include <fstream>
#include <tuple>
#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <utility>
#include <map>
#include <regex>
#include "debug/debugger/debugger.h"
#include "debug/data_dump/dump_json_parser.h"
#include "pipeline/jit/pipeline.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "debug/data_dump/e2e_dump.h"
#include "utils/config_manager.h"
#include "debug/env_config_parser.h"
#include "utils/comm_manager.h"
#include "runtime/hardware/device_context_manager.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/proto_exporter.h"
#else
#include "debug/debugger/proto_exporter_stub.h"
#endif

using debugger::Chunk;
using debugger::EventReply;
using debugger::GraphProto;
using debugger::ModelProto;
using debugger::Statistics;
using debugger::TensorProto;
using debugger::WatchCondition;
using debugger::WatchCondition_Condition_inf;
using debugger::WatchCondition_Condition_nan;
using debugger::WatchCondition_Parameter;
using debugger::WatchNode;
using debugger::WatchpointHit;

namespace mindspore {

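// g_chunk_size caps each gRPC payload at 3 MB; graph protos and tensor
// contents larger than this are split into chunks below (presumably to stay
// under gRPC's default 4 MB message limit). The heartbeat period is 30 s.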
static constexpr auto g_chunk_size = 1024 * 1024 * 3;
static constexpr int32_t heartbeat_period_second = 30;
DebuggerPtr Debugger::debugger_ = nullptr;
std::mutex Debugger::instance_lock_;

Debugger::Debugger()
    : grpc_client_(nullptr),
      debug_services_(nullptr),
      heartbeat_thread_(nullptr),
      device_id_(0),
      device_target_(""),
      num_step_(0),
      debugger_enabled_(false),
      suspended_at_last_kernel_(false),
      run_level_(""),
      node_name_(""),
      cur_name_(""),
      training_done_(false),
      is_dataset_graph_(false),
      partial_memory_(false),
      initial_suspend_(true),
      enable_heartbeat_(false),
      not_dataset_graph_sum_(0),
      version_("") {
  CheckDebuggerEnabledParam();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  if (device_target == kCPUDevice) {
    MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  } else if (CheckDebuggerEnabled()) {
    // configure partial memory reuse
    partial_memory_ = CheckDebuggerPartialMemoryEnabled();

    // switch memory reuse on or off
    EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
    // print some message about memory reuse to user
    if (partial_memory_) {
      MS_LOG(WARNING)
        << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
           "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
    } else {
      MS_LOG(WARNING)
        << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
           "usage for large models.";
    }
  }
}

void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // save device_id
  MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  device_id_ = device_id;
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  device_target_ = device_target;
  version_ = "1.5.0";
}

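// The debugger can load every numeric tensor type except complex64; all
// other types are skipped with an INFO log.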
bool IsTypeDebuggerSupported(TypeId type) {
  if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
    return true;
  }
  MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  return false;
}

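// EnableDebugger resets the connection state, then (re)creates the gRPC client
// and the heartbeat thread when ENABLE_MS_DEBUGGER is set. A typical
// configuration, with illustrative values, would be:
//
//   export ENABLE_MS_DEBUGGER=1
//   export MS_DEBUGGER_HOST=127.0.0.1  # must be a valid IPv4 address
//   export MS_DEBUGGER_PORT=50051      # 1-65535; defaults to 50051 when unset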
void Debugger::EnableDebugger() {
  // reset some of the class members
  num_step_ = 0;
  debugger_enabled_ = false;
  enable_heartbeat_ = false;
  partial_memory_ = false;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  heartbeat_thread_ = nullptr;

  // see if dump using debugger backend is enabled
  bool dump_enabled = CheckDebuggerDumpEnabled();
  MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;

  // check if debugger enabled
  debugger_enabled_ = CheckDebuggerEnabled();
  MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;

  if (!debugger_enabled_ && !dump_enabled) {
    MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
    return;
  }

  if (debugger_enabled_) {
    // configure grpc host
    std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
    std::string host;
    if (!env_host_str.empty()) {
      if (CheckIp(env_host_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
        host = env_host_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
                                    "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
      }
    } else {
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
      host = "localhost";
    }
    // configure grpc port
    std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
    std::string port;
    if (!env_port_str.empty()) {
      if (CheckPort(env_port_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
        port = env_port_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Valid ports range from 1 to "
                                    "65535";
      }
    } else {
      port = "50051";
      if (!CheckPort(port)) {
        MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Valid ports range from 1 to 65535";
      }
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
    }
    // initialize grpc client
    grpc_client_ = std::make_unique<GrpcClient>(host, port);
    // initialize sending heartbeat
    heartbeat_thread_ = std::make_unique<std::thread>([this]() { SendHeartbeat(heartbeat_period_second); });
  }
  debug_services_ = std::make_unique<DebugServices>();
}

void Debugger::CheckDatasetSinkMode() {
  if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "e2e_dump is not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }

  if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }
}

bool Debugger::CheckDebuggerDumpEnabled() const {
  // see if dump is enabled
  if (device_target_ == kGPUDevice) {
    return device::KernelRuntime::DumpDataEnabled();
  }
  return false;
}

bool Debugger::CheckDebuggerEnabled() const {
  // get env variables to configure debugger
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
      return true;
    }
  }
  return false;
}

void Debugger::CheckDebuggerEnabledParam() const {
  // check the value of env variable ENABLE_MS_DEBUGGER
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
      MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but got: "
                      << env_enable_str;
    }
  }
}

bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  if (!env_partial_mem_str.empty()) {
    MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
    if (env_partial_mem_str == "1") {
      return true;
    }
  }
  return false;
}

bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }

void Debugger::Reset() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // reset components
  if (heartbeat_thread_ && heartbeat_thread_->joinable()) {
    SetEnableHeartbeat(false);
    heartbeat_thread_->join();
    MS_LOG(INFO) << "Join Heartbeat thread.";
  }
  heartbeat_thread_ = nullptr;
  device_id_ = 0;
  device_target_ = "";
  num_step_ = 0;
  debugger_enabled_ = false;
  is_dataset_graph_ = false;
  partial_memory_ = false;
  graph_ptr_ = nullptr;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  graph_proto_list_.clear();
  graph_ptr_list_.clear();
  graph_ptr_step_vec_.clear();
  MS_LOG(INFO) << "Release Debugger resource.";
}

void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  E2eDump::UpdateIterGPUDump();
  // Store graphs that are run in one step.
  graph_ptr_step_vec_ = graphs;
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    if (debugger_) {
      debugger_->PreExecute(graph);
    }
    DumpSetup(graph);
  }
}

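// PreExecute runs at the start of each step. For the initial step it sends the
// compiled graph proto(s) to MindInsight and suspends; for later GPU steps it
// re-enters the command loop (unless already suspended at the last kernel) and
// resets the tensors loaded in the previous step.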
void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  CheckDatasetSinkMode();
  auto graph_id = graph_ptr->graph_id();
  // collect rungraph_ids to update the step number in the multigraph case
  if (rungraph_id_list_.empty()) {
    rungraph_id_list_.push_back(graph_id);
  } else {
    if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
      rungraph_id_list_.push_back(graph_id);
    }
  }
  // multiple graphs
  if (graph_proto_list_.size() > 1) {
    // more than one graph is not a dataset graph
    if (not_dataset_graph_sum_ > 0) {
      SendMultiGraphsAndClear(graph_ptr);
    }
  } else if (graph_proto_list_.size() == 1) {
    // single graph, and not the initial step
    if (device_target_ == kGPUDevice && num_step_ != 0) {
      if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
        CommandLoop();
      }
      debug_services_->ResetLoadedTensors();
    }
    // In the single-graph case, reset graph_ptr_ to nullptr for the initial step
    if (num_step_ == 0) {
      graph_ptr_ = nullptr;
      CheckGraphPtr(graph_ptr);
    }
  } else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
    // Multiple graphs, and not the initial step;
    // stop only when receiving the first subgraph run of each step.
    // If we have already stopped for the last kernel, no need to stop again.
    if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
      return;
    }
    if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
      CommandLoop();
    }
    debug_services_->ResetLoadedTensors();
  }
  // resets for the new graph
  suspended_at_last_kernel_ = false;
}

void Debugger::SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr) {
  // only try to enable the debugger if not all graphs are dataset graphs
  if (!debugger_enabled_) {
    EnableDebugger();
  }
  if (debugger_enabled_) {
    // only send compiled graphs once at the initial step.
    auto dbg_graph_ptr = graph_ptr_;
    // use current graph ptr to load parameters
    graph_ptr_ = graph_ptr;
    LoadParametersAndConst();
    // revert graph ptr to original value
    graph_ptr_ = dbg_graph_ptr;

    SendMultiGraphsAndSuspend(graph_proto_list_);

    graph_proto_list_.clear();
  }
}

bool Debugger::DumpDataEnabledIteration() const {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  if (dump_json_parser.IsDumpIter(cur_iter)) {
    return true;
  }
  return false;
}

uint32_t Debugger::GetRankID() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  uint32_t rank_id = device_context->GetRankID();
  return rank_id;
}

void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  uint32_t rank_id = GetRankID();
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    (void)E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
  } else {
    DumpJsonParser::GetInstance().UpdateDumpIter();
  }
}

void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    uint32_t rank_id = GetRankID();
    (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  }
}

void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  E2eDump::DumpSetup(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}

void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  // This function will be called for the new GPU runtime using MindRTBackend
  auto &json_parser = DumpJsonParser::GetInstance();
  if (json_parser.e2e_dump_enabled()) {
    uint32_t rank_id = GetRankID();
    kernel_graph->set_root_graph_id(kernel_graph->graph_id());
    std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
    std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
    std::string target_dir = root_dir + "/graphs";
    std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
    DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
    DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
    DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
                      kernel_graph->execution_order());
  }
}

void Debugger::PostExecuteGraphDebugger() {
  // On CPU, update the dump iteration; Parameters and consts are not dumped here
  if (device_target_ == kCPUDevice) {
    DumpJsonParser::GetInstance().UpdateDumpIter();
    return;
  }
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  // LoadParametersAndConst for all the graphs that have been run in the current step
  if (debugger_) {
    for (auto graph : graph_ptr_step_vec_) {
      debugger_->LoadParametersAndConst(graph);
    }
  }
  // dump Parameters and consts if debugger dump is enabled
  if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
    for (auto graph : graph_ptr_step_vec_) {
      debugger_->Dump(graph);
      if (!debugger_->debugger_enabled()) {
        debugger_->ClearCurrentData();
      }
    }
  }
  if (debugger_) {
    debugger_->PostExecute();
  }
}

void Debugger::PostExecute() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    // analyze tensor data and send the watchpoints that have been hit
    if (debugger_enabled_ && !is_dataset_graph_) {
      if (device_target_ != kGPUDevice) {
        num_step_++;
      }
      SendWatchpoints(CheckWatchpoints());

      // no need to suspend at each graph for GPU; suspension happens in PreExecute
      if (device_target_ != kGPUDevice) {
        CommandLoop();
      }
    }
    // Only keep parameters in the current map
    // GPU ResetLoadedTensors happens in PreExecute
    if (device_target_ != kGPUDevice) {
      debug_services_->ResetLoadedTensors();
    }
  }
}

bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
    // if the node has a watchpoint on it, or is a next_to or continue_to node, read the kernel tensor data
    if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
      return true;
    }
  }
  return false;
}

void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);

    // if the kernel is a watchpoint and gets hit, suspend
    bool hit_empty_flag = true;
    if (is_watchpoint) {
      auto hits = CheckWatchpoints(cur_name_, kernel);
      if (!hits.empty()) {
        SendWatchpoints(hits);
        CommandLoop();

        hit_empty_flag = false;
      }
    }
    if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
      // if the kernel is not a watchpoint but is a next_to or continue_to node, suspend;
      // set a bool to be checked in PreExecute to avoid double stopping at the last kernel in the last graph
      if (last_kernel) {
        suspended_at_last_kernel_ = true;
      }
      CommandLoop();
    }
    return;
  }
}

void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    CheckDatasetGraph();
    if (!is_dataset_graph_) {
      // get proto for new graph_ptr
      auto graph_proto = GetGraphProto(graph_ptr);
      // add new graph proto to graph_proto_list_
      graph_proto_list_.push_back(graph_proto);
      graph_ptr_list_.push_back(graph_ptr);
      not_dataset_graph_sum_++;
    }
    // reset is_dataset_graph to false
    is_dataset_graph_ = false;
  }
}

// In single-graph cases, check the single graph ptr
void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    if (!is_dataset_graph_) {
      // only try to enable the debugger if it is not a dataset graph
      EnableDebugger();
      if (debugger_enabled_) {
        LoadParametersAndConst();
        // get graph proto and send it to MindInsight
        auto graph_proto = graph_proto_list_.front();
        SendGraphAndSuspend(graph_proto);
      }
    }
  }
}

void Debugger::CheckDatasetGraph() {
  // print parameter node names
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &params = graph_ptr_->inputs();
  for (const auto &param : params) {
    MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  }
  // check if there is a GetNext or InitDataSetQueue node
  const auto &nodes = graph_ptr_->execution_order();
  for (const auto &node : nodes) {
    auto node_name = AnfAlgo::GetCNodeName(node);
    MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
    if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
      MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
                   << node_name;
      is_dataset_graph_ = true;
      return;
    }
  }
  is_dataset_graph_ = false;
}

GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  // convert kernel graph to debugger ModelProto
  ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  return model.graph();
}

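// SendHeartbeat keeps the MindInsight connection alive: one Heartbeat message
// every `period` seconds. Each failure is retried after 500 ms; after 5
// consecutive failures the heartbeat is disabled and the loop exits.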
void Debugger::SendHeartbeat(int32_t period) {
  int num_heartbeat_fail = 0;
  const int max_num_heartbeat_fail = 5;
  const int retry_milliseconds = 500;

  Heartbeat heartbeat;
  heartbeat.set_message("Debugger is alive");
  heartbeat.set_period(heartbeat_period_second);

  SetEnableHeartbeat(CheckDebuggerEnabled());
  while (enable_heartbeat_) {
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendHeartbeat(heartbeat);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendHeartbeat failed";
      num_heartbeat_fail++;
      if (num_heartbeat_fail >= max_num_heartbeat_fail) {
        MS_LOG(ERROR) << "Maximum number of failures for SendHeartbeat reached: exiting training session.";
        SetEnableHeartbeat(false);
        break;
      } else {
        MS_LOG(ERROR) << "Number of consecutive SendHeartbeat failures: " << num_heartbeat_fail;
        std::this_thread::sleep_for(std::chrono::milliseconds(retry_milliseconds));
      }
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(period * 1000));
    }
  }
}

void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  if (SendMetadata(true)) {
    // send graph to MindInsight server
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendGraph(graph_proto);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendGraph failed";
    }
    // enter command loop, wait for and process commands
    CommandLoop();
  }
}

bool Debugger::SendMetadata(bool version_check) {
  // prepare metadata
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  metadata.set_ms_version(version_);
  MS_LOG(INFO) << "Is training done? " << training_done_;
  // set graph number to not_dataset_graph_sum_
  metadata.set_graph_num(not_dataset_graph_sum_);

  MS_EXCEPTION_IF_NULL(grpc_client_);
  EventReply reply_metadata = grpc_client_->SendMetadata(metadata);

  bool ret = false;
  if (reply_metadata.status() == reply_metadata.OK) {
    if (version_check) {
      // get the type of the command in the metadata reply; it should be version matched
      DebuggerCommand cmd = GetCommand(reply_metadata);
      if (cmd != DebuggerCommand::kVersionMatchedCMD) {
        MS_LOG(ERROR) << "MindInsight version is too old, MindSpore version is " << version_;
        Exit();
      } else {
        if (GetMiVersionMatched(reply_metadata)) {
          MS_LOG(INFO) << "MindSpore version " << version_ << " matches the MindInsight version.";
          ret = true;
        } else {
          MS_LOG(ERROR) << "MindSpore version " << version_ << " did not match the MindInsight version.";
          CommandLoop();
        }
      }
    } else {
      // version check was done before, so we can just return true here
      ret = true;
    }
  } else {
    MS_LOG(ERROR) << "Error: SendMetadata failed";
  }

  return ret;
}

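// Graphs are serialized and streamed as Chunk messages of at most g_chunk_size
// bytes; only the final chunk of each graph has finished=true, which lets the
// server reassemble multi-chunk graphs before parsing them.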
void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  if (!SendMetadata(true)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(grpc_client_);
  // send multiple graphs to the MindInsight server
  // split a graph into chunks if it is larger than the chunk size
  std::list<Chunk> chunked_graph_proto_list;
  Chunk chunk;
  for (auto graph : graph_proto_list) {
    std::string str = graph.SerializeAsString();
    auto graph_size = graph.ByteSize();
    if (graph_size > g_chunk_size) {
      auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);

      for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
        chunk.set_buffer(sub_graph_str[i]);
        if (i < sub_graph_str.size() - 1) {
          chunk.set_finished(false);
        } else {
          chunk.set_finished(true);
        }
        chunked_graph_proto_list.push_back(chunk);
      }
    } else {
      chunk.set_buffer(str);
      chunk.set_finished(true);
      chunked_graph_proto_list.push_back(chunk);
    }
  }
  EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  if (reply.status() != reply.OK) {
    MS_LOG(ERROR) << "Error: SendMultiGraphs failed";
  }
  // enter command loop, wait for and process commands
  CommandLoop();
}

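// CommandLoop blocks the training thread on WaitForCommand and dispatches the
// reply: Exit and (non-recheck) Run commands leave the loop, Set and View
// commands are processed in place. Connection failures are retried up to 5
// times with a backoff of num_wait_fail seconds before terminating.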
void Debugger::CommandLoop() {
  // prepare metadata
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;

  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);

  // loop exit flag
  bool run = false;
  int num_wait_fail = 0;
  const int max_num_wait_fail = 5;

  while (!run) {
    // wait for command
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->WaitForCommand(metadata);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: WaitForCommand failed";
      num_wait_fail++;
      if (num_wait_fail > max_num_wait_fail) {
        MS_LOG(ERROR) << "Maximum number of WaitForCommand retries reached: exiting training session.";
        MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
                         "of debugger host and port.";
        Exit();
        run = true;
      } else {
        MS_LOG(ERROR) << "Number of consecutive WaitForCommand failures: " << num_wait_fail << "; retry after "
                      << num_wait_fail << "s";
        std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
      }
      continue;
    }

    // get the type of the command in the reply
    DebuggerCommand cmd = GetCommand(reply);
    if (cmd == DebuggerCommand::kUnknownCMD) {
      MS_LOG(DEBUG) << "Debug: debugger received unknown command";
      continue;
    }

    MS_LOG(INFO) << "received command: ";
    switch (cmd) {
      case DebuggerCommand::kUnknownCMD:
        MS_LOG(INFO) << "UnknownCMD";
        break;
      case DebuggerCommand::kExitCMD:
        MS_LOG(INFO) << "ExitCMD";
        Exit(true);
        // Used for debugger termination
        run = true;
        break;
      case DebuggerCommand::kRunCMD:
        ProcessRunCMD(reply);
        if (GetRunLevel(reply) != "recheck") {
          // exit loop
          run = true;
        }
        break;
      case DebuggerCommand::kSetCMD:
        ProcessKSetCMD(reply);
        break;
      case DebuggerCommand::kViewCMD:
        ProcessKViewCMD(reply);
        break;
      case DebuggerCommand::kVersionMatchedCMD:
        MS_LOG(ERROR) << "Received unexpected Version Matched CMD from MindInsight.";
        Exit();
        break;
      default:
        MS_LOG(ERROR) << "Received unknown CMD from MindInsight";
        Exit();
        break;
    }
  }
}

void Debugger::ProcessRunCMD(const EventReply &reply) {
  MS_LOG(INFO) << "RunCMD";
  if (GetRunLevel(reply) == "recheck") {
    MS_LOG(INFO) << "rechecking all watchpoints";
    SendWatchpoints(CheckWatchpoints("", nullptr, true));
  } else {
    // no longer the initial suspension.
    initial_suspend_ = false;
    // print run cmd content
    // get run_level and node_name
    run_level_ = GetRunLevel(reply);
    node_name_ = GetNodeName(reply);

    MS_LOG(INFO) << "run_level: " << run_level_;
    MS_LOG(INFO) << "node_name_: " << node_name_;
  }
}

void Debugger::ProcessKSetCMD(const EventReply &reply) {
  MS_LOG(INFO) << "SetCMD";
  MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  if (GetWatchpointDelete(reply)) {
    MS_LOG(INFO) << "Deleting watchpoint";
    RemoveWatchpoint(GetWatchpointID(reply));
  } else {
    MS_LOG(INFO) << "Setting watchpoint";
    MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
    ProtoVector<WatchNode> received_nodes = GetWatchnodes(reply);
    for (const auto &node : received_nodes) {
      MS_LOG(INFO) << "node name: " << node.node_name();
      MS_LOG(INFO) << "node type: " << node.node_type();
    }
    ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
    for (const auto &parameter : parameters) {
      MS_LOG(INFO) << "parameter name: " << parameter.name();
      MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
      MS_LOG(INFO) << "parameter value: " << parameter.value();
    }
    SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  }
}

void Debugger::ProcessKViewCMD(const EventReply &reply) {
  MS_LOG(INFO) << "ViewCMD";
  // print view cmd content
  ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  for (auto received_tensor : received_tensors) {
    MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  }

  switch (reply.view_cmd().level()) {
    case debugger::ViewCMD_Level::ViewCMD_Level_base:
      MS_LOG(INFO) << "Tensor base request.";
      ViewBaseLevel(reply);
      break;

    case debugger::ViewCMD_Level::ViewCMD_Level_statistics:
      MS_LOG(INFO) << "Tensor statistics request.";
      ViewStatLevel(reply);
      break;

    case debugger::ViewCMD_Level::ViewCMD_Level_value:
      MS_LOG(INFO) << "Tensor value request.";
      ViewValueLevel(reply);
      break;
    default:
      MS_LOG(DEBUG) << "Debug: Unknown tensor info level";
      break;
  }
}

void Debugger::ViewValueLevel(const EventReply &reply) {
  MS_LOG(INFO) << "Sending tensors";
  std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  // print view cmd reply
  for (auto tensor : tensors) {
    MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
    MS_LOG(INFO) << "tensor dims: ";
    for (auto dim : tensor.dims()) {
      MS_LOG(INFO) << dim << ",";
    }
    MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  }
  MS_EXCEPTION_IF_NULL(grpc_client_);
  EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  if (send_tensors_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensors failed";
  }
}

void Debugger::ViewStatLevel(const EventReply &reply) {
  std::list<TensorSummary> tensor_stats_list = LoadTensorsStat(GetTensors(reply));
  EventReply send_tensors_stat_reply = grpc_client_->SendTensorStats(tensor_stats_list);
  if (send_tensors_stat_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensorStats failed.";
  }
}

void Debugger::ViewBaseLevel(const EventReply &reply) {
  std::list<TensorBase> tensor_base_list = LoadTensorsBase(GetTensors(reply));
  EventReply send_tensor_base_reply = grpc_client_->SendTensorBase(tensor_base_list);
  if (send_tensor_base_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensorBase failed.";
  }
}

void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  tensor_item->set_node_name(tensor.node_name());
  tensor_item->set_slot(tensor.slot());
  tensor_item->set_iter(tensor.iter());
  tensor_item->set_truncate(tensor.truncate());
  tensor_item->clear_tensor_content();
  tensor_item->clear_data_type();
  tensor_item->clear_dims();
}

void AddTensorStatInfo(const DebugServices::TensorStat &tensor_stat,
                       std::list<TensorSummary> *const tensor_summary_list) {
  if (tensor_summary_list == nullptr) {
    MS_LOG(DEBUG) << "tensor_summary_list is nullptr.";
    return;
  }
  TensorSummary tensor_summary_item;
  TensorBase *tensor_base = tensor_summary_item.mutable_tensor_base();
  tensor_base->set_data_type(tensor_stat.dtype);
  tensor_base->set_data_size((int64_t)tensor_stat.data_size);
  for (auto elem : tensor_stat.shape) {
    tensor_base->add_shape(elem);
  }

  Statistics *tensor_statistics = tensor_summary_item.mutable_statistics();
  tensor_statistics->set_is_bool(tensor_stat.is_bool);
  tensor_statistics->set_max_value(static_cast<float>(tensor_stat.max_value));
  tensor_statistics->set_min_value(static_cast<float>(tensor_stat.min_value));
  tensor_statistics->set_avg_value(static_cast<float>(tensor_stat.avg_value));
  tensor_statistics->set_count(tensor_stat.count);
  tensor_statistics->set_neg_zero_count(tensor_stat.neg_zero_count);
  tensor_statistics->set_pos_zero_count(tensor_stat.pos_zero_count);
  tensor_statistics->set_nan_count(tensor_stat.nan_count);
  tensor_statistics->set_neg_inf_count(tensor_stat.neg_inf_count);
  tensor_statistics->set_pos_inf_count(tensor_stat.pos_inf_count);
  tensor_statistics->set_zero_count(tensor_stat.zero_count);

  tensor_summary_list->push_back(tensor_summary_item);
}

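// SetWatchpoint translates the proto request into DebugServices types: each
// WatchNode becomes a (node_name, is_scope_node) tuple and each condition
// parameter becomes a parameter_t before the watchpoint is registered.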
void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
                             const ProtoVector<WatchCondition_Parameter> &parameters) {
  std::vector<std::tuple<std::string, bool>> check_node_list;
  std::vector<DebugServices::parameter_t> parameter_list;

  std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
                 [](const WatchNode &node) -> std::tuple<std::string, bool> {
                   return make_tuple(node.node_name(), node.node_type() == "scope");
                 });

  std::transform(
    parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
    [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
      return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
    });
  debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
}

void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }

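// LoadTensors streams tensor values back to MindInsight. ret_name holds the
// names actually found in the TensorLoader, in the same order as the request;
// a requested tensor that was not found yields a single empty proto with
// finished=true, while found tensors are split into g_chunk_size pieces.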
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  std::vector<std::string> name;
  std::vector<std::string> ret_name;
  std::vector<const char *> data_ptr;
  std::vector<ssize_t> data_size;
  std::vector<unsigned int> dtype;
  std::vector<std::vector<int64_t>> shape;

  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);

  // ret_name will contain tensor names that are found in TensorLoader
  // items in ret_name will be in the same order as in tensors, if found
  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  std::list<TensorProto> tensor_list;
  size_t result_index = 0;

  for (auto tensor : tensors) {
    ssize_t size_iter = 0;
    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      // return an empty tensor if the requested tensor was not found
      TensorProto tensor_item;
      tensor_item.set_finished(true);
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_list.push_back(tensor_item);
      continue;
    }
    ssize_t tensor_size = data_size[result_index];
    while (size_iter < tensor_size) {
      ssize_t chunk_size = g_chunk_size;
      TensorProto tensor_item;
      tensor_item.set_finished(false);
      if (tensor_size - size_iter <= g_chunk_size) {
        chunk_size = tensor_size - size_iter;
        tensor_item.set_finished(true);
      }
      AddTensorProtoInfo(&tensor_item, tensor);

      tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);

      tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
      for (auto &elem : shape[result_index]) {
        tensor_item.add_dims(elem);
      }
      // add tensor to result list and increment result_index to check next item in ret_name
      tensor_list.push_back(tensor_item);
      if (size_iter > INT_MAX - g_chunk_size) {
        MS_EXCEPTION(ValueError) << size_iter << " + " << g_chunk_size << " would lead to integer overflow!";
      }
      size_iter += g_chunk_size;
    }
    result_index++;
  }
  return tensor_list;
}

std::list<TensorBase> Debugger::LoadTensorsBase(const ProtoVector<TensorProto> &tensors) const {
  std::list<TensorBase> tensor_base_list;
  std::vector<std::string> name;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  debug_services_->SearchNodesTensors(name, &result_list);
  for (auto result : result_list) {
    auto tensor = std::get<1>(result);
    if (!tensor) {
      // tensor was not found, creating empty tensor base.
      TensorBase tensor_base_item;
      tensor_base_item.set_data_size(0);
      tensor_base_item.set_data_type(0);
      tensor_base_item.add_shape(0);
      tensor_base_list.push_back(tensor_base_item);
      continue;
    }
    // tensor was found, creating tensor base object.
    TensorBase tensor_base_item;
    tensor_base_item.set_data_size((int64_t)tensor->GetByteSize());
    tensor_base_item.set_data_type((int32_t)tensor->GetType());
    for (auto elem : tensor->GetShape()) {
      tensor_base_item.add_shape(elem);
    }
    tensor_base_list.push_back(tensor_base_item);
  }
  return tensor_base_list;
}

std::list<TensorSummary> Debugger::LoadTensorsStat(const ProtoVector<TensorProto> &tensors) const {
  std::list<TensorSummary> tensor_summary_list;
  std::vector<std::string> name;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  debug_services_->SearchNodesTensors(name, &result_list);
  for (auto result : result_list) {
    auto tensor = std::get<1>(result);
    if (!tensor) {
      // tensor was not found, creating empty tensor summary.
      DebugServices::TensorStat tensor_stat;
      AddTensorStatInfo(tensor_stat, &tensor_summary_list);
      continue;
    }
    // tensor was found, creating tensor summary object.
    DebugServices::TensorStat tensor_stat = debug_services_->GetTensorStatistics(tensor);
    AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  }
  return tensor_summary_list;
}

void Debugger::Exit(bool exit_success) {
  // debugger will notify the main thread to exit, because the main thread can only exit at a step boundary
  MS_LOG(INFO) << "Exit Debugger";
  SetEnableHeartbeat(false);
  pipeline::GraphExecutorPy::DebugTerminate(true, exit_success);
}

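// CheckWatchpoints fills a set of parallel vectors (name, slot, condition,
// watchpoint_id, parameters, error_codes) that are indexed together: entry i
// of each vector describes the i-th hit. An empty watchnode means "check all
// loaded tensors", which is how the step-end and recheck paths call it.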
std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
                                                    bool recheck) {
  std::vector<std::string> name;
  std::vector<std::string> slot;
  std::vector<int> condition;
  std::vector<unsigned int> watchpoint_id;
  std::vector<std::string> overflow_ops;
  std::vector<std::vector<DebugServices::parameter_t>> parameters;
  std::vector<int32_t> error_codes;
  std::vector<std::shared_ptr<TensorData>> tensor_list;
  if (watchnode.empty()) {
    tensor_list = debug_services_->GetTensor();
  } else {
    tensor_list = debug_services_->GetNodeTensor(kernel);
  }
  std::vector<std::string> file_list;
  MS_LOG(INFO) << "CheckWatchpoints call for step " << num_step_;
  debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
                                    file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  std::list<WatchpointHit> hits;
  for (unsigned int i = 0; i < name.size(); i++) {
    WatchpointHit hit;
    std::vector<DebugServices::parameter_t> &parameter = parameters[i];
    hit.set_id(watchpoint_id[i]);
    hit.set_error_code(error_codes[i]);
    // here TensorProto acts as a tensor indicator, not sending tensor content
    TensorProto *tensor_item = hit.mutable_tensor();
    tensor_item->set_node_name(name[i]);
    tensor_item->set_slot(slot[i]);
    tensor_item->set_finished(true);

    WatchCondition *condition_item = hit.mutable_watch_condition();
    condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
    for (const auto &p : parameter) {
      auto x = condition_item->mutable_params()->Add();
      x->set_name(p.name);
      x->set_disabled(p.disabled);
      x->set_value(p.value);
      x->set_hit(p.hit);
      x->set_actual_value(p.actual_value);
    }
    hits.push_back(hit);
  }
  return hits;
}

void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  // send info about watchpoint
  if (!points.empty()) {
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendWatchpointHits(points);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
    }
  }
}

bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
                                const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
                                TypeId device_type, const std::string &addr_format, size_t slot) const {
  return debug_services_->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
                                           device_type, addr_format, slot);
}

bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  return debug_services_->LoadNewTensor(tensor, keep_prev);
}

bool Debugger::debugger_enabled() const { return debugger_enabled_; }

DebuggerCommand GetCommand(const EventReply &reply) {
  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  switch (reply.cmd_case()) {
    case debugger::EventReply::CmdCase::kExit:
      cmd = DebuggerCommand::kExitCMD;
      break;
    case debugger::EventReply::CmdCase::kRunCmd:
      cmd = DebuggerCommand::kRunCMD;
      break;
    case debugger::EventReply::CmdCase::kSetCmd:
      cmd = DebuggerCommand::kSetCMD;
      break;
    case debugger::EventReply::CmdCase::kViewCmd:
      cmd = DebuggerCommand::kViewCMD;
      break;
    case debugger::EventReply::CmdCase::kVersionMatched:
      cmd = DebuggerCommand::kVersionMatchedCMD;
      break;
    default:
      MS_LOG(DEBUG) << "Debug: UnknownCMD";
      break;
  }
  return cmd;
}

ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
    return ProtoVector<WatchCondition_Parameter>();
  }
  return reply.set_cmd().watch_condition().params();
}

ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
    return ProtoVector<WatchNode>();
  }
  return reply.set_cmd().watch_nodes();
}

std::string GetRunLevel(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: empty string.";
    return "";
  }
  return reply.run_cmd().run_level();
}

std::string GetNodeName(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: empty string.";
    return "";
  }
  return reply.run_cmd().node_name();
}

WatchCondition GetWatchcondition(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
    return WatchCondition();
  }
  return reply.set_cmd().watch_condition();
}

int32_t GetWatchpointID(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
    return 0;
  }
  return reply.set_cmd().id();
}

bool GetWatchpointDelete(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
    return false;
  }
  return reply.set_cmd().delete_();
}

ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  if (!reply.has_view_cmd()) {
    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
    return ProtoVector<TensorProto>();
  }
  return reply.view_cmd().tensors();
}

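// The full tensor name has the form "node_name:slot" or "node_name:slot:iter".
// With truncate set, only the last scope component is kept, e.g. a
// (hypothetical) name "Default/network/Conv2D-op1" with slot "0" becomes
// "Conv2D-op1:0".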
std::string GetTensorFullName(const TensorProto &tensor) {
  string node_name = tensor.node_name();
  if (tensor.truncate()) {
    // scopes in the node name are separated by '/'
    // use the name without the scope if truncate is true
    std::size_t found = node_name.find_last_of("/");
    node_name = node_name.substr(found + 1);
  }
  return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
}

bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }

bool Debugger::partial_memory() const { return partial_memory_; }

void Debugger::SetEnableHeartbeat(bool enabled) { enable_heartbeat_ = enabled; }

void Debugger::SetCurNode(const std::string &cur_name) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  cur_name_ = cur_name;
}

std::string Debugger::run_level() const { return run_level_; }

void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }

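// CheckPort accepts a decimal string in [1, 65535] with no leading zeros and
// no non-digit characters: "50051" is valid, while "0", "070", "65536" and
// "50051x" are all rejected.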
bool Debugger::CheckPort(const std::string &port) const {
  int num = 0;
  const int min_port_num = 1;
  const int max_port_num = 65535;
  const int decimal = 10;
  if (port[0] == '0' && port[1] != '\0') return false;
  int i = 0;
  while (port[i] != '\0') {
    if (port[i] < '0' || port[i] > '9') return false;
    num = num * decimal + (port[i] - '0');
    if (num > max_port_num) return false;
    i++;
  }
  if (num < min_port_num) return false;
  return true;
}

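// CheckIp matches dotted-quad IPv4 addresses. Note that the first and last
// octets are restricted to 1-254 (so addresses such as 0.x.x.x or x.x.x.255
// are rejected), while the middle octets allow the full 0-255 range.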
bool Debugger::CheckIp(const std::string &host) const {
  std::regex reg_ip(
    "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  std::smatch smat;
  std::string host_str = host;
  return std::regex_match(host_str, smat, reg_ip);
}

uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }

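// LoadSingleAnfnode copies one Parameter or ValueNode output from device to
// host. Parameters are loaded with keep_prev=true (the previous value is moved
// aside first), presumably so previous-step values stay available to
// watchpoints; value nodes are not.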
void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
    return;
  }
  // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) && (device_target_ == kGPUDevice)) {
    if (!anf_node->isa<ValueNode>() &&
        !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
      return;
    }
  }
  // for parameters and value nodes, set their execution order to 0
  int exec_order = 0;
  std::string node_name = GetKernelNodeName(anf_node);
  GetFileKernelName(NOT_NULL(&node_name));
  // check if the output addr exists; if not, return
  if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
    return;
  }
  auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  MS_EXCEPTION_IF_NULL(addr);
  auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  if (!IsTypeDebuggerSupported(type)) {
    return;
  }
  auto format = kOpFormat_DEFAULT;
  string tensor_name = node_name + ':' + "0";
  ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  bool keep_prev;
  if (anf_node->isa<Parameter>()) {
    keep_prev = true;
    debug_services_->MoveTensorCurrentToPrev(tensor_name);
  } else {
    keep_prev = false;
  }
  bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  if (!ret) {
    MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
  }
}

void Debugger::LoadParametersAndConst() {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters for graph " << graph_ptr_->graph_id() << ".";
  const auto &parameters = graph_ptr_->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes for graph " << graph_ptr_->graph_id() << ".";
  const auto value_nodes = graph_ptr_->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}

void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id() << ".";
  const auto &parameters = graph_ptr_->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id() << ".";
  const auto value_nodes = graph_ptr_->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}

void Debugger::LoadGraphOutputs() {
  if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &apply_kernels = graph_ptr_->execution_order();
  // for kernels, execution order starts from 1
  int exec_order = 1;
  for (const auto &node : apply_kernels) {
    MS_EXCEPTION_IF_NULL(node);
    std::string kernel_name = GetKernelNodeName(node);
    auto output_size = AnfAlgo::GetOutputTensorNum(node);
    if (partial_memory_) {
      if (!debug_services_->IsWatchPoint(kernel_name, node)) {
        continue;
      }
    }
    for (size_t j = 0; j < output_size; ++j) {
      if (!AnfAlgo::OutputAddrExist(node, j)) {
        MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
        continue;
      }
      auto addr = AnfAlgo::GetOutputAddr(node, j);
      MS_EXCEPTION_IF_NULL(addr);
      auto type = AnfAlgo::GetOutputInferDataType(node, j);
      if (!IsTypeDebuggerSupported(type)) {
        continue;
      }
      auto format = kOpFormat_DEFAULT;
      string tensor_name = kernel_name + ':' + std::to_string(j);
      ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
      auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
      if (!ret) {
        MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
      }
    }
    exec_order = exec_order + 1;
  }
}

void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(debugger_);
  // update the step number if we are processing the first graph (to support multigraph)
  if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
      (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::UpdateStepNumGPU() {
  // UpdateStepNum with DebugActor::DebugOnStepEnd
  if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::ClearCurrentData() {
  if ((device_target_ == kGPUDevice) && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
    if (debug_services_) {
      debug_services_->EmptyCurrentTensor();
    } else {
      MS_LOG(ERROR) << "debug_services_ is nullptr";
    }
  }
}

bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  return debug_services_->TensorExistsInCurrent(tensor_name);
}

}  // namespace mindspore