/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dirent.h>
#include <climits>
#include <cstdio>
#include <fstream>
#include <tuple>
#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <utility>
#include <map>
#include <regex>
#include "debug/debugger/debugger.h"
#include "debug/data_dump/dump_json_parser.h"
#include "pipeline/jit/pipeline.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime.h"
#include "debug/data_dump/e2e_dump.h"
#include "utils/config_manager.h"
#include "debug/env_config_parser.h"
#include "utils/comm_manager.h"
#include "runtime/hardware/device_context_manager.h"
#include "debug/anf_ir_dump.h"
#include "debug/anf_ir_utils.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/proto_exporter.h"
#else
#include "debug/debugger/proto_exporter_stub.h"
#endif

using debugger::Chunk;
using debugger::EventReply;
using debugger::GraphProto;
using debugger::ModelProto;
using debugger::Statistics;
using debugger::TensorProto;
using debugger::WatchCondition;
using debugger::WatchCondition_Condition_inf;
using debugger::WatchCondition_Condition_nan;
using debugger::WatchCondition_Parameter;
using debugger::WatchNode;
using debugger::WatchpointHit;

namespace mindspore {

static constexpr auto g_chunk_size = 1024 * 1024 * 3;
static constexpr int32_t heartbeat_period_second = 30;
DebuggerPtr Debugger::debugger_ = nullptr;
std::mutex Debugger::instance_lock_;

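// The constructor runs once for the singleton. It reads the configured device
// target, refuses to enable the debugger on CPU, and configures partial memory
// reuse from MS_DEBUGGER_PARTIAL_MEM (reduces memory usage for large models at
// the cost of only watched tensors being available).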
Debugger::Debugger()
    : grpc_client_(nullptr),
      debug_services_(nullptr),
      heartbeat_thread_(nullptr),
      device_id_(0),
      device_target_(""),
      num_step_(0),
      debugger_enabled_(false),
      suspended_at_last_kernel_(false),
      run_level_(""),
      node_name_(""),
      cur_name_(""),
      training_done_(false),
      is_dataset_graph_(false),
      partial_memory_(false),
      initial_suspend_(true),
      enable_heartbeat_(false),
      not_dataset_graph_sum_(0),
      version_("") {
  CheckDebuggerEnabledParam();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  if (device_target == kCPUDevice) {
    MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
  } else if (CheckDebuggerEnabled()) {
    // configure partial memory reuse
    partial_memory_ = CheckDebuggerPartialMemoryEnabled();

    // switch memory reuse on or off
    EnvConfigParser::GetInstance().SetSysMemreuse(partial_memory_);
    // print some message about memory reuse to user
    if (partial_memory_) {
      MS_LOG(WARNING)
        << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
           "step. 2. Tensor values are only available for nodes that are watched by any watchpoint.";
    } else {
      MS_LOG(WARNING)
        << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory "
           "usage for large models.";
    }
  }
}

void Debugger::Init(const uint32_t device_id, const std::string device_target) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // save device_id
  MS_LOG(INFO) << "Debugger got device_id: " << device_id;
  device_id_ = device_id;
  MS_LOG(INFO) << "Debugger got device_target: " << device_target;
  device_target_ = device_target;
  version_ = "1.5.0";
}

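// Only number types (excluding complex64) can be read back by the debugger.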
bool IsTypeDebuggerSupported(TypeId type) {
  if (type < TypeId::kNumberTypeEnd && type > TypeId::kNumberTypeBegin && type != kNumberTypeComplex64) {
    return true;
  }
  MS_LOG(INFO) << "Debugger does not support type: " << TypeIdLabel(type);
  return false;
}

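// EnableDebugger resets the debugger state and, when enabled, connects the gRPC
// client and starts the heartbeat thread. Connection settings come from
// environment variables, e.g. (typical usage):
//   export ENABLE_MS_DEBUGGER=1
//   export MS_DEBUGGER_HOST=127.0.0.1  # must be a valid IP; defaults to localhost
//   export MS_DEBUGGER_PORT=50051      # must be in [1, 65535]; defaults to 50051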
void Debugger::EnableDebugger() {
  // reset some of the class members
  num_step_ = 0;
  debugger_enabled_ = false;
  enable_heartbeat_ = false;
  partial_memory_ = false;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  heartbeat_thread_ = nullptr;

  // see if dump using debugger backend is enabled
  bool dump_enabled = CheckDebuggerDumpEnabled();
  MS_LOG(INFO) << "dump using debugger backend = " << dump_enabled;

  // check if debugger enabled
  debugger_enabled_ = CheckDebuggerEnabled();
  MS_LOG(INFO) << "debugger_enabled_ = " << debugger_enabled_;

  if (!debugger_enabled_ && !dump_enabled) {
    MS_LOG(INFO) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger.";
    return;
  }

  if (debugger_enabled_) {
    // configure grpc host
    std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
    std::string host;
    if (!env_host_str.empty()) {
      if (CheckIp(env_host_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
        host = env_host_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
                                    "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
      }
    } else {
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
      host = "localhost";
    }
    // configure grpc port
    std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
    std::string port;
    if (!env_port_str.empty()) {
      if (CheckPort(env_port_str)) {
        MS_LOG(INFO) << "Getenv MS_DEBUGGER_PORT: " << env_port_str;
        port = env_port_str;
      } else {
        debugger_enabled_ = false;
        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_PORT is not valid. Please set a custom port "
                                    "in the range 1 to 65535";
      }
    } else {
      port = "50051";
      if (!CheckPort(port)) {
        MS_EXCEPTION(ValueError) << "Default MS_DEBUGGER_PORT is not valid. Please set a custom port in the range "
                                    "1 to 65535";
      }
      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_PORT doesn't exist. Using default debugger port: 50051";
    }
    // initialize grpc client
    grpc_client_ = std::make_unique<GrpcClient>(host, port);
    // initialize sending heartbeat
    heartbeat_thread_ = std::make_unique<std::thread>([this]() { SendHeartbeat(heartbeat_period_second); });
  }
  debug_services_ = std::make_unique<DebugServices>();
}

void Debugger::CheckDatasetSinkMode() {
  if (CheckDebuggerDumpEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "e2e_dump is not supported on GPU with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }

  if (CheckDebuggerEnabled() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
    MS_EXCEPTION(NotSupportError)
      << "Debugger is not supported with dataset_sink_mode=True. Please set dataset_sink_mode=False";
  }
}

bool Debugger::CheckDebuggerDumpEnabled() const {
  // see if dump is enabled
  if (device_target_ == kGPUDevice) {
    return device::KernelRuntime::DumpDataEnabled();
  }
  return false;
}

bool Debugger::CheckDebuggerEnabled() const {
  // get env variables to configure debugger
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if ((env_enable_str == "1" || env_enable_str == "true") && device_target_ != kCPUDevice) {
      return true;
    }
  }
  return false;
}

void Debugger::CheckDebuggerEnabledParam() const {
  // check the value of env variable ENABLE_MS_DEBUGGER
  std::string env_enable_str = common::GetEnv("ENABLE_MS_DEBUGGER");
  if (!env_enable_str.empty()) {
    (void)std::transform(env_enable_str.begin(), env_enable_str.end(), env_enable_str.begin(), ::tolower);
    if (env_enable_str != "0" && env_enable_str != "1" && env_enable_str != "false" && env_enable_str != "true") {
      MS_LOG(WARNING) << "Env variable ENABLE_MS_DEBUGGER should be True/False/1/0 (case insensitive), but got: "
                      << env_enable_str;
    }
  }
}

bool Debugger::CheckDebuggerPartialMemoryEnabled() const {
  std::string env_partial_mem_str = common::GetEnv("MS_DEBUGGER_PARTIAL_MEM");
  if (!env_partial_mem_str.empty()) {
    MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str;
    if (env_partial_mem_str == "1") {
      return true;
    }
  }
  return false;
}

bool Debugger::DebuggerBackendEnabled() const { return CheckDebuggerDumpEnabled() || CheckDebuggerEnabled(); }

void Debugger::Reset() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  // reset components
  if (heartbeat_thread_ && heartbeat_thread_->joinable()) {
    SetEnableHeartbeat(false);
    heartbeat_thread_->join();
    MS_LOG(INFO) << "Join Heartbeat thread.";
  }
  heartbeat_thread_ = nullptr;
  device_id_ = 0;
  device_target_ = "";
  num_step_ = 0;
  debugger_enabled_ = false;
  is_dataset_graph_ = false;
  partial_memory_ = false;
  graph_ptr_ = nullptr;
  grpc_client_ = nullptr;
  debug_services_ = nullptr;
  graph_proto_list_.clear();
  graph_ptr_list_.clear();
  graph_ptr_step_vec_.clear();
  MS_LOG(INFO) << "Release Debugger resource.";
}

void Debugger::PreExecuteGraphDebugger(const std::vector<KernelGraphPtr> &graphs) {
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  E2eDump::UpdateIterGPUDump();
  // Store graphs that are run in one step.
  graph_ptr_step_vec_ = graphs;
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    if (debugger_) {
      debugger_->PreExecute(graph);
    }
    DumpSetup(graph);
  }
}

void Debugger::PreExecute(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  CheckDatasetSinkMode();
  auto graph_id = graph_ptr->graph_id();
  // collect rungraph_ids to update the step number in the multigraph case
  if (rungraph_id_list_.empty()) {
    rungraph_id_list_.push_back(graph_id);
  } else {
    if (std::find(rungraph_id_list_.begin(), rungraph_id_list_.end(), graph_id) == rungraph_id_list_.end()) {
      rungraph_id_list_.push_back(graph_id);
    }
  }
  // multiple graphs
  if (graph_proto_list_.size() > 1) {
    // there is more than one graph that is not a dataset graph
    if (not_dataset_graph_sum_ > 0) {
      SendMultiGraphsAndClear(graph_ptr);
    }
  } else if (graph_proto_list_.size() == 1) {
    // single graph, and not the initial step
    if (device_target_ == kGPUDevice && num_step_ != 0) {
      if (debugger_enabled_ && !(run_level_ == "node" && suspended_at_last_kernel_)) {
        CommandLoop();
      }
      debug_services_->ResetLoadedTensors();
    }
    // In the single graph case, reset graph_ptr_ to nullptr for the initial step
    if (num_step_ == 0) {
      graph_ptr_ = nullptr;
      CheckGraphPtr(graph_ptr);
    }
  } else if (debugger_enabled_ && graph_id == rungraph_id_list_.front() && device_target_ == kGPUDevice) {
    // Multiple graphs, and not the initial step;
    // stop only when we receive the first sub run graph for each step.
    // If we have already stopped for the last kernel, no need to stop again.
    if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
      return;
    }
    if (!(run_level_ == "node" && suspended_at_last_kernel_)) {
      CommandLoop();
    }
    debug_services_->ResetLoadedTensors();
  }
  // resets for the new graph
  suspended_at_last_kernel_ = false;
}

void Debugger::SendMultiGraphsAndClear(const KernelGraphPtr &graph_ptr) {
  // only try to enable the debugger if not all graphs are dataset graphs
  if (!debugger_enabled_) {
    EnableDebugger();
  }
  if (debugger_enabled_) {
    // only send compiled graphs once at the initial step.
    auto dbg_graph_ptr = graph_ptr_;
    // use current graph ptr to load parameters
    graph_ptr_ = graph_ptr;
    LoadParametersAndConst();
    // revert graph ptr to original value
    graph_ptr_ = dbg_graph_ptr;

    SendMultiGraphsAndSuspend(graph_proto_list_);

    graph_proto_list_.clear();
  }
}

bool Debugger::DumpDataEnabledIteration() const {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  if (dump_json_parser.IsDumpIter(cur_iter)) {
    return true;
  }
  return false;
}

uint32_t Debugger::GetRankID() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_target, device_id});
  uint32_t rank_id = device_context->GetRankID();
  return rank_id;
}

void Debugger::Dump(const KernelGraphPtr &kernel_graph) const {
  uint32_t rank_id = GetRankID();
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    MS_EXCEPTION_IF_NULL(kernel_graph);
    (void)E2eDump::DumpParametersAndConstData(kernel_graph.get(), rank_id, debugger_.get());
  } else {
    DumpJsonParser::GetInstance().UpdateDumpIter();
  }
}

void Debugger::DumpSingleNode(const CNodePtr &node, uint32_t graph_id) {
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    uint32_t rank_id = GetRankID();
    (void)E2eDump::DumpSingleNodeData(node, graph_id, rank_id, debugger_.get());
  }
}

void Debugger::DumpSetup(const KernelGraphPtr &kernel_graph) const {
  MS_LOG(INFO) << "Start!";
  MS_EXCEPTION_IF_NULL(kernel_graph);
  E2eDump::DumpSetup(kernel_graph.get());
  MS_LOG(INFO) << "Finish!";
}

void Debugger::DumpInGraphCompiler(const KernelGraphPtr &kernel_graph) {
  // This function will be called for the new GPU runtime using MindRTBackend
  auto &json_parser = DumpJsonParser::GetInstance();
  if (json_parser.e2e_dump_enabled()) {
    uint32_t rank_id = GetRankID();
    kernel_graph->set_root_graph_id(kernel_graph->graph_id());
    std::string final_graph = "trace_code_graph_" + std::to_string(kernel_graph->graph_id());
    std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
    std::string target_dir = root_dir + "/graphs";
    std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
    DumpIRProtoWithSrcInfo(kernel_graph, final_graph, target_dir, kDebugWholeStack);
    DumpIR("trace_code_graph", kernel_graph, true, kWholeStack, ir_file_path);
    DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(kernel_graph->graph_id()) + ".csv", root_dir,
                      kernel_graph->execution_order());
  }
}

void Debugger::PostExecuteGraphDebugger() {
  // On CPU, just update the dump iteration; Parameters and consts are not dumped here
  if (device_target_ == kCPUDevice) {
    DumpJsonParser::GetInstance().UpdateDumpIter();
    return;
  }
  // Only GPU is supported for MindRTBackend
  if (device_target_ != kGPUDevice) {
    return;
  }
  // LoadParametersAndConst for all the graphs that have been run in the current step
  if (debugger_) {
    for (auto graph : graph_ptr_step_vec_) {
      debugger_->LoadParametersAndConst(graph);
    }
  }
  // debug used for dump
  if (debugger_ && debugger_->CheckDebuggerDumpEnabled()) {
    // Dump Parameters and consts
    for (auto graph : graph_ptr_step_vec_) {
      debugger_->Dump(graph);
      if (!debugger_->debugger_enabled()) {
        debugger_->ClearCurrentData();
      }
    }
  }
  if (debugger_) {
    debugger_->PostExecute();
  }
}

void Debugger::PostExecute() {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_ && debugger_->DebuggerBackendEnabled()) {
    // analyze tensor data and send any watchpoints that have been hit
    if (debugger_enabled_ && !is_dataset_graph_) {
      if (device_target_ != kGPUDevice) {
        num_step_++;
      }
      SendWatchpoints(CheckWatchpoints());

      // no need to suspend at each graph for GPU; suspension happens in PreExecute
      if (device_target_ != kGPUDevice) {
        CommandLoop();
      }
    }
    // Only keep parameters in the current map
    // GPU ResetLoadedTensors happens in PreExecute
    if (device_target_ != kGPUDevice) {
      debug_services_->ResetLoadedTensors();
    }
  }
}

bool Debugger::ReadNodeDataRequired(const CNodePtr &kernel) const {
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);
    // if the node has a watchpoint on it, or is a next_to or continue_to node, read the kernel tensor data
    if (is_watchpoint || (run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_))) {
      return true;
    }
  }
  return false;
}

void Debugger::PostExecuteNode(const CNodePtr &kernel, bool last_kernel) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  if (pipeline::GraphExecutorPy::GetDebugTerminate()) {
    return;
  }
  if (debugger_enabled_ && !is_dataset_graph_) {
    auto is_watchpoint = debug_services_->IsWatchPoint(cur_name_, kernel);

    // if the kernel is a watchpoint and it is hit, suspend
    bool hit_empty_flag = true;
    if (is_watchpoint) {
      auto hits = CheckWatchpoints(cur_name_, kernel);
      if (!hits.empty()) {
        SendWatchpoints(hits);
        CommandLoop();

        hit_empty_flag = false;
      }
    }
    if (hit_empty_flag && run_level_ == "node" && (node_name_ == "" || node_name_ == cur_name_)) {
      // if the kernel is not a watchpoint but is a next_to or continue_to node, suspend.
      // Set a bool to be checked in PreExecute to avoid stopping twice at the last kernel of the last graph.
      if (last_kernel) {
        suspended_at_last_kernel_ = true;
      }
      CommandLoop();
    }
    return;
  }
}

void Debugger::LoadGraphs(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "LoadGraphs Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    CheckDatasetGraph();
    if (!is_dataset_graph_) {
      // get proto for new graph_ptr
      auto graph_proto = GetGraphProto(graph_ptr);
      // add new graph proto to graph_proto_list_
      graph_proto_list_.push_back(graph_proto);
      graph_ptr_list_.push_back(graph_ptr);
      not_dataset_graph_sum_++;
    }
    // reset is_dataset_graph to be false
    is_dataset_graph_ = false;
  }
}

// In single graph cases, check the single graph ptr
void Debugger::CheckGraphPtr(const KernelGraphPtr &graph_ptr) {
  MS_EXCEPTION_IF_NULL(graph_ptr);
  if (graph_ptr_ != graph_ptr) {
    MS_LOG(INFO) << "CheckGraphPtr Debugger got new graph: " << graph_ptr->graph_id();
    // save new graph_ptr
    graph_ptr_ = graph_ptr;
    if (!is_dataset_graph_) {
      // only try to enable the debugger if it is not a dataset graph
      EnableDebugger();
      if (debugger_enabled_) {
        LoadParametersAndConst();
        // get graph proto and send it to MindInsight
        auto graph_proto = graph_proto_list_.front();
        SendGraphAndSuspend(graph_proto);
      }
    }
  }
}

void Debugger::CheckDatasetGraph() {
  // print parameter node names
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &params = graph_ptr_->inputs();
  for (const auto &param : params) {
    MS_LOG(INFO) << "param: " << GetKernelNodeName(param);
  }
  // check if there is a GetNext or InitDataSetQueue node
  const auto &nodes = graph_ptr_->execution_order();
  for (const auto &node : nodes) {
    auto node_name = AnfAlgo::GetCNodeName(node);
    MS_LOG(INFO) << "node: " << GetKernelNodeName(node);
    if (node_name == "GetNext" || node_name == "InitDataSetQueue") {
      MS_LOG(INFO) << "Not enabling debugger for graph " << graph_ptr_->graph_id() << ": found dataset graph node "
                   << node_name;
      is_dataset_graph_ = true;
      return;
    }
  }
  is_dataset_graph_ = false;
}

GraphProto Debugger::GetGraphProto(const KernelGraphPtr &graph_ptr) const {
  // convert kernel graph to debugger modelproto
  ModelProto model = GetDebuggerFuncGraphProto(graph_ptr);
  return model.graph();
}

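// SendHeartbeat keeps the MindInsight connection alive: one Heartbeat message
// every `period` seconds. Up to 5 consecutive failures are retried after 500 ms
// each; after that, heartbeat is disabled and the session is considered lost.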
void Debugger::SendHeartbeat(int32_t period) {
  int num_heartbeat_fail = 0;
  const int max_num_heartbeat_fail = 5;
  const int retry_milliseconds = 500;

  Heartbeat heartbeat;
  heartbeat.set_message("Debugger is alive");
  heartbeat.set_period(heartbeat_period_second);

  SetEnableHeartbeat(CheckDebuggerEnabled());
  while (enable_heartbeat_) {
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendHeartbeat(heartbeat);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendHeartbeat failed";
      num_heartbeat_fail++;
      if (num_heartbeat_fail >= max_num_heartbeat_fail) {
        MS_LOG(ERROR) << "Maximum number of failures for SendHeartbeat reached: exiting training session.";
        SetEnableHeartbeat(false);
        break;
      } else {
        MS_LOG(ERROR) << "Number of consecutive SendHeartbeat failures: " << num_heartbeat_fail;
        std::this_thread::sleep_for(std::chrono::milliseconds(retry_milliseconds));
      }
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(period * 1000));
    }
  }
}

void Debugger::SendGraphAndSuspend(const GraphProto &graph_proto) {
  if (SendMetadata(true)) {
    // send graph to MindInsight server
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendGraph(graph_proto);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendGraph failed";
    }
    // enter command loop, wait for and process commands
    CommandLoop();
  }
}

bool Debugger::SendMetadata(bool version_check) {
  // prepare metadata
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;
  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);
  metadata.set_ms_version(version_);
  MS_LOG(INFO) << "Is training done? " << training_done_;
  // set graph number to not_dataset_graph_sum_
  metadata.set_graph_num(not_dataset_graph_sum_);

  MS_EXCEPTION_IF_NULL(grpc_client_);
  EventReply reply_metadata = grpc_client_->SendMetadata(metadata);

  bool ret = false;
  if (reply_metadata.status() == reply_metadata.OK) {
    if (version_check) {
      // get the type of the command in the metadata reply; it should be version matched
      DebuggerCommand cmd = GetCommand(reply_metadata);
      if (cmd != DebuggerCommand::kVersionMatchedCMD) {
        MS_LOG(ERROR) << "MindInsight version is too old, MindSpore version is " << version_;
        Exit();
      } else {
        if (GetMiVersionMatched(reply_metadata)) {
          MS_LOG(INFO) << "MindSpore version " << version_ << " matches the MindInsight version.";
          ret = true;
        } else {
          MS_LOG(ERROR) << "MindSpore version " << version_ << " did not match the MindInsight version.";
          CommandLoop();
        }
      }
    } else {
      // version check was done before, so we can just return true here
      ret = true;
    }
  } else {
    MS_LOG(ERROR) << "Error: SendMetadata failed";
  }

  return ret;
}

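// Graphs larger than g_chunk_size (3 MB) are split into chunks before being
// streamed to MindInsight; only the last chunk of each graph has finished=true,
// so the receiver can reassemble graph boundaries.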
void Debugger::SendMultiGraphsAndSuspend(const std::list<GraphProto> &graph_proto_list) {
  if (!SendMetadata(true)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(grpc_client_);
  // send multiple graphs to the MindInsight server,
  // splitting a graph into chunks if it is larger than the chunk size
  std::list<Chunk> chunked_graph_proto_list;
  Chunk chunk;
  for (auto graph : graph_proto_list) {
    std::string str = graph.SerializeAsString();
    auto graph_size = graph.ByteSize();
    if (graph_size > g_chunk_size) {
      auto sub_graph_str = grpc_client_->ChunkString(str, graph_size);

      for (unsigned int i = 0; i < sub_graph_str.size(); i++) {
        chunk.set_buffer(sub_graph_str[i]);
        if (i < sub_graph_str.size() - 1) {
          chunk.set_finished(false);
        } else {
          chunk.set_finished(true);
        }
        chunked_graph_proto_list.push_back(chunk);
      }
    } else {
      chunk.set_buffer(str);
      chunk.set_finished(true);
      chunked_graph_proto_list.push_back(chunk);
    }
  }
  EventReply reply = grpc_client_->SendMultiGraphs(chunked_graph_proto_list);
  if (reply.status() != reply.OK) {
    MS_LOG(ERROR) << "Error: SendMultiGraphs failed";
  }
  // enter command loop, wait for and process commands
  CommandLoop();
}

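// CommandLoop blocks the training thread and serves MindInsight commands until
// a run command (other than "recheck") or an exit command arrives. WaitForCommand
// failures are retried up to 5 times with a linearly growing backoff; after that
// the session is terminated.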
void Debugger::CommandLoop() {
  // prepare metadata
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  std::string device_name = std::to_string(device_id_) + ":" + std::to_string(graph_ptr_->graph_id());
  Metadata metadata;

  metadata.set_device_name(device_name);
  metadata.set_cur_step(num_step_);
  metadata.set_backend(device_target_);
  metadata.set_cur_node(cur_name_);
  metadata.set_training_done(training_done_);

  // loop exit flag
  bool run = false;
  int num_wait_fail = 0;
  const int max_num_wait_fail = 5;

  while (!run) {
    // wait for command
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->WaitForCommand(metadata);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: WaitForCommand failed";
      num_wait_fail++;
      if (num_wait_fail > max_num_wait_fail) {
        MS_LOG(ERROR) << "Maximum number of WaitForCommand retries reached: exiting training session.";
        MS_LOG(ERROR) << "Failed to connect to MindInsight debugger server. Please check the config "
                         "of the debugger host and port.";
        Exit();
        run = true;
      } else {
        MS_LOG(ERROR) << "Number of consecutive WaitForCommand failures: " << num_wait_fail << "; Retry after "
                      << num_wait_fail << "s";
        std::this_thread::sleep_for(std::chrono::seconds(num_wait_fail));
      }
      continue;
    }

    // get the type of the command in the reply
    DebuggerCommand cmd = GetCommand(reply);
    if (cmd == DebuggerCommand::kUnknownCMD) {
      MS_LOG(DEBUG) << "Debug: debugger received unknown command";
      continue;
    }

    MS_LOG(INFO) << "received command: ";
    switch (cmd) {
      case DebuggerCommand::kUnknownCMD:
        MS_LOG(INFO) << "UnknownCMD";
        break;
      case DebuggerCommand::kExitCMD:
        MS_LOG(INFO) << "ExitCMD";
        Exit(true);
        // Used for debugger termination
        run = true;
        break;
      case DebuggerCommand::kRunCMD:
        ProcessRunCMD(reply);
        if (GetRunLevel(reply) != "recheck") {
          // exit loop
          run = true;
        }
        break;
      case DebuggerCommand::kSetCMD:
        ProcessKSetCMD(reply);
        break;
      case DebuggerCommand::kViewCMD:
        ProcessKViewCMD(reply);
        break;
      case DebuggerCommand::kVersionMatchedCMD:
        MS_LOG(ERROR) << "Received unexpected Version Matched CMD from MindInsight.";
        Exit();
        break;
      default:
        MS_LOG(ERROR) << "Received unknown CMD from MindInsight";
        Exit();
        break;
    }
  }
}

void Debugger::ProcessRunCMD(const EventReply &reply) {
  MS_LOG(INFO) << "RunCMD";
  if (GetRunLevel(reply) == "recheck") {
    MS_LOG(INFO) << "rechecking all watchpoints";
    SendWatchpoints(CheckWatchpoints("", nullptr, true));
  } else {
    // no longer the initial suspension.
    initial_suspend_ = false;
    // print run cmd content: get run_level and node_name
    run_level_ = GetRunLevel(reply);
    node_name_ = GetNodeName(reply);

    MS_LOG(INFO) << "run_level: " << run_level_;
    MS_LOG(INFO) << "node_name_: " << node_name_;
  }
}

void Debugger::ProcessKSetCMD(const EventReply &reply) {
  MS_LOG(INFO) << "SetCMD";
  MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
  MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
  if (GetWatchpointDelete(reply)) {
    MS_LOG(INFO) << "Deleting watchpoint";
    RemoveWatchpoint(GetWatchpointID(reply));
  } else {
    MS_LOG(INFO) << "Setting watchpoint";
    MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
    ProtoVector<WatchNode> received_nodes = GetWatchnodes(reply);
    for (const auto &node : received_nodes) {
      MS_LOG(INFO) << "node name: " << node.node_name();
      MS_LOG(INFO) << "node type: " << node.node_type();
    }
    ProtoVector<WatchCondition_Parameter> parameters = GetParameters(reply);
    for (const auto &parameter : parameters) {
      MS_LOG(INFO) << "parameter name: " << parameter.name();
      MS_LOG(INFO) << "parameter is disabled: " << parameter.disabled();
      MS_LOG(INFO) << "parameter value: " << parameter.value();
    }
    SetWatchpoint(GetWatchnodes(reply), GetWatchcondition(reply), GetWatchpointID(reply), GetParameters(reply));
  }
}

void Debugger::ProcessKViewCMD(const EventReply &reply) {
  MS_LOG(INFO) << "ViewCMD";
  // print view cmd content
  ProtoVector<TensorProto> received_tensors = GetTensors(reply);
  for (auto received_tensor : received_tensors) {
    MS_LOG(INFO) << "tensor node name: " << received_tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << received_tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << received_tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << received_tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << received_tensor.truncate() << std::noboolalpha;
  }

  switch (reply.view_cmd().level()) {
    case debugger::ViewCMD_Level::ViewCMD_Level_base:
      MS_LOG(INFO) << "Tensor base request.";
      ViewBaseLevel(reply);
      break;

    case debugger::ViewCMD_Level::ViewCMD_Level_statistics:
      MS_LOG(INFO) << "Tensor statistics request.";
      ViewStatLevel(reply);
      break;

    case debugger::ViewCMD_Level::ViewCMD_Level_value:
      MS_LOG(INFO) << "Tensor value request.";
      ViewValueLevel(reply);
      break;
    default:
      MS_LOG(DEBUG) << "Debug: Unknown tensor info level";
      break;
  }
}

void Debugger::ViewValueLevel(const EventReply &reply) {
  MS_LOG(INFO) << "Sending tensors";
  std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
  // print view cmd reply
  for (auto tensor : tensors) {
    MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
    MS_LOG(INFO) << "tensor slot: " << tensor.slot();
    MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
    MS_LOG(INFO) << "tensor iter: " << tensor.iter();
    MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
    MS_LOG(INFO) << "tensor dims: ";
    for (auto dim : tensor.dims()) {
      MS_LOG(INFO) << dim << ",";
    }
    MS_LOG(INFO) << "tensor dtype: " << tensor.data_type();
  }
  MS_EXCEPTION_IF_NULL(grpc_client_);
  EventReply send_tensors_reply = grpc_client_->SendTensors(tensors);
  if (send_tensors_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensors failed";
  }
}

void Debugger::ViewStatLevel(const EventReply &reply) {
  std::list<TensorSummary> tensor_stats_list = LoadTensorsStat(GetTensors(reply));
  EventReply send_tensors_stat_reply = grpc_client_->SendTensorStats(tensor_stats_list);
  if (send_tensors_stat_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensorStats failed.";
  }
}

void Debugger::ViewBaseLevel(const EventReply &reply) {
  std::list<TensorBase> tensor_base_list = LoadTensorsBase(GetTensors(reply));
  EventReply send_tensor_base_reply = grpc_client_->SendTensorBase(tensor_base_list);
  if (send_tensor_base_reply.status() != debugger::EventReply::OK) {
    MS_LOG(ERROR) << "Error: SendTensorBase failed.";
  }
}

void AddTensorProtoInfo(TensorProto *tensor_item, const TensorProto &tensor) {
  tensor_item->set_node_name(tensor.node_name());
  tensor_item->set_slot(tensor.slot());
  tensor_item->set_iter(tensor.iter());
  tensor_item->set_truncate(tensor.truncate());
  tensor_item->clear_tensor_content();
  tensor_item->clear_data_type();
  tensor_item->clear_dims();
}

void AddTensorStatInfo(const DebugServices::TensorStat &tensor_stat,
                       std::list<TensorSummary> *const tensor_summary_list) {
  if (tensor_summary_list == nullptr) {
    MS_LOG(DEBUG) << "tensor_summary_list is nullptr.";
    return;
  }
  TensorSummary tensor_summary_item;
  TensorBase *tensor_base = tensor_summary_item.mutable_tensor_base();
  tensor_base->set_data_type(tensor_stat.dtype);
  tensor_base->set_data_size((int64_t)tensor_stat.data_size);
  for (auto elem : tensor_stat.shape) {
    tensor_base->add_shape(elem);
  }

  Statistics *tensor_statistics = tensor_summary_item.mutable_statistics();
  tensor_statistics->set_is_bool(tensor_stat.is_bool);
  tensor_statistics->set_max_value(static_cast<float>(tensor_stat.max_value));
  tensor_statistics->set_min_value(static_cast<float>(tensor_stat.min_value));
  tensor_statistics->set_avg_value(static_cast<float>(tensor_stat.avg_value));
  tensor_statistics->set_count(tensor_stat.count);
  tensor_statistics->set_neg_zero_count(tensor_stat.neg_zero_count);
  tensor_statistics->set_pos_zero_count(tensor_stat.pos_zero_count);
  tensor_statistics->set_nan_count(tensor_stat.nan_count);
  tensor_statistics->set_neg_inf_count(tensor_stat.neg_inf_count);
  tensor_statistics->set_pos_inf_count(tensor_stat.pos_inf_count);
  tensor_statistics->set_zero_count(tensor_stat.zero_count);

  tensor_summary_list->push_back(tensor_summary_item);
}

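// SetWatchpoint converts the proto watch nodes and condition parameters into
// DebugServices types and registers the watchpoint under the given id.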
void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id,
                             const ProtoVector<WatchCondition_Parameter> &parameters) {
  std::vector<std::tuple<std::string, bool>> check_node_list;
  std::vector<DebugServices::parameter_t> parameter_list;

  std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
                 [](const WatchNode &node) -> std::tuple<std::string, bool> {
                   return make_tuple(node.node_name(), node.node_type() == "scope");
                 });

  std::transform(
    parameters.begin(), parameters.end(), std::back_inserter(parameter_list),
    [](const WatchCondition_Parameter &parameter) -> DebugServices::parameter_t {
      return DebugServices::parameter_t{parameter.name(), parameter.disabled(), parameter.value(), parameter.hit()};
    });
  debug_services_->AddWatchpoint(id, condition.condition(), condition.value(), check_node_list, parameter_list);
}

void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); }

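// LoadTensors streams tensor contents back to MindInsight. Tensors that are not
// found in the TensorLoader are returned as empty protos with finished=true;
// found tensors are split into g_chunk_size pieces, with finished set only on
// the last chunk.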
std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
  std::vector<std::string> name;
  std::vector<std::string> ret_name;
  std::vector<const char *> data_ptr;
  std::vector<ssize_t> data_size;
  std::vector<unsigned int> dtype;
  std::vector<std::vector<int64_t>> shape;

  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);

  // ret_name will contain the tensor names that are found in TensorLoader;
  // items in ret_name will be in the same order as in tensors, if found
  debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
  std::list<TensorProto> tensor_list;
  size_t result_index = 0;

  for (auto tensor : tensors) {
    ssize_t size_iter = 0;
    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      // return an empty finished tensor if the requested tensor was not found
      TensorProto tensor_item;
      tensor_item.set_finished(true);
      AddTensorProtoInfo(&tensor_item, tensor);
      tensor_list.push_back(tensor_item);
      continue;
    }
    ssize_t tensor_size = data_size[result_index];
    while (size_iter < tensor_size) {
      ssize_t chunk_size = g_chunk_size;
      TensorProto tensor_item;
      tensor_item.set_finished(false);
      if (tensor_size - size_iter <= g_chunk_size) {
        chunk_size = tensor_size - size_iter;
        tensor_item.set_finished(true);
      }
      AddTensorProtoInfo(&tensor_item, tensor);

      tensor_item.set_tensor_content(data_ptr[result_index] + size_iter, chunk_size);

      tensor_item.set_data_type((debugger::DataType)dtype[result_index]);
      for (auto &elem : shape[result_index]) {
        tensor_item.add_dims(elem);
      }
      // add the chunk to the result list and advance to the next chunk
      tensor_list.push_back(tensor_item);
      if (size_iter > INT_MAX - g_chunk_size) {
        MS_EXCEPTION(ValueError) << size_iter << " + " << g_chunk_size << " would lead to integer overflow!";
      }
      size_iter += g_chunk_size;
    }
    // increment result_index to check the next item in ret_name
    result_index++;
  }
  return tensor_list;
}

std::list<TensorBase> Debugger::LoadTensorsBase(const ProtoVector<TensorProto> &tensors) const {
  std::list<TensorBase> tensor_base_list;
  std::vector<std::string> name;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  debug_services_->SearchNodesTensors(name, &result_list);
  for (auto result : result_list) {
    auto tensor = std::get<1>(result);
    if (!tensor) {
      // tensor was not found, creating an empty tensor base.
      TensorBase tensor_base_item;
      tensor_base_item.set_data_size(0);
      tensor_base_item.set_data_type(0);
      tensor_base_item.add_shape(0);
      tensor_base_list.push_back(tensor_base_item);
      continue;
    }
    // tensor was found, creating a tensor base object.
    TensorBase tensor_base_item;
    tensor_base_item.set_data_size((int64_t)tensor->GetByteSize());
    tensor_base_item.set_data_type((int32_t)tensor->GetType());
    for (auto elem : tensor->GetShape()) {
      tensor_base_item.add_shape(elem);
    }
    tensor_base_list.push_back(tensor_base_item);
  }
  return tensor_base_list;
}

std::list<TensorSummary> Debugger::LoadTensorsStat(const ProtoVector<TensorProto> &tensors) const {
  std::list<TensorSummary> tensor_summary_list;
  std::vector<std::string> name;
  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
  std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> result_list;
  debug_services_->SearchNodesTensors(name, &result_list);
  for (auto result : result_list) {
    auto tensor = std::get<1>(result);
    if (!tensor) {
      // tensor was not found, creating an empty tensor summary.
      DebugServices::TensorStat tensor_stat;
      AddTensorStatInfo(tensor_stat, &tensor_summary_list);
      continue;
    }
    // tensor was found, creating a tensor summary object.
    DebugServices::TensorStat tensor_stat = debug_services_->GetTensorStatistics(tensor);
    AddTensorStatInfo(tensor_stat, &tensor_summary_list);
  }
  return tensor_summary_list;
}

void Debugger::Exit(bool exit_success) {
  // debugger will notify the main thread to exit, because the main thread can only exit at a step boundary.
  MS_LOG(INFO) << "Exit Debugger";
  SetEnableHeartbeat(false);
  pipeline::GraphExecutorPy::DebugTerminate(true, exit_success);
}

std::list<WatchpointHit> Debugger::CheckWatchpoints(const std::string &watchnode, const CNodePtr &kernel,
                                                    bool recheck) {
  std::vector<std::string> name;
  std::vector<std::string> slot;
  std::vector<int> condition;
  std::vector<unsigned int> watchpoint_id;
  std::vector<std::string> overflow_ops;
  std::vector<std::vector<DebugServices::parameter_t>> parameters;
  std::vector<int32_t> error_codes;
  std::vector<std::shared_ptr<TensorData>> tensor_list;
  if (watchnode.empty()) {
    tensor_list = debug_services_->GetTensor();
  } else {
    tensor_list = debug_services_->GetNodeTensor(kernel);
  }
  std::vector<std::string> file_list;
  MS_LOG(INFO) << "CheckWatchpoints call for step " << num_step_;
  debug_services_->CheckWatchpoints(&name, &slot, &condition, &watchpoint_id, &parameters, &error_codes, overflow_ops,
                                    file_list, &tensor_list, initial_suspend_, watchnode.empty(), recheck);
  std::list<WatchpointHit> hits;
  for (unsigned int i = 0; i < name.size(); i++) {
    WatchpointHit hit;
    std::vector<DebugServices::parameter_t> &parameter = parameters[i];
    hit.set_id(watchpoint_id[i]);
    hit.set_error_code(error_codes[i]);
    // here TensorProto acts as a tensor indicator; tensor content is not sent
    TensorProto *tensor_item = hit.mutable_tensor();
    tensor_item->set_node_name(name[i]);
    tensor_item->set_slot(slot[i]);
    tensor_item->set_finished(true);

    WatchCondition *condition_item = hit.mutable_watch_condition();
    condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
    for (const auto &p : parameter) {
      auto x = condition_item->mutable_params()->Add();
      x->set_name(p.name);
      x->set_disabled(p.disabled);
      x->set_value(p.value);
      x->set_hit(p.hit);
      x->set_actual_value(p.actual_value);
    }
    hits.push_back(hit);
  }
  return hits;
}

void Debugger::SendWatchpoints(const std::list<WatchpointHit> &points) {
  // send info about watchpoint
  if (!points.empty()) {
    MS_EXCEPTION_IF_NULL(grpc_client_);
    EventReply reply = grpc_client_->SendWatchpointHits(points);
    if (reply.status() != reply.OK) {
      MS_LOG(ERROR) << "Error: SendWatchpointHits failed";
    }
  }
}

bool Debugger::DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
                                const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
                                TypeId device_type, const std::string &addr_format, size_t slot) const {
  return debug_services_->DumpTensorToFile(tensor_name, trans_flag, filepath, host_fmt, host_shape, host_type,
                                           device_type, addr_format, slot);
}

bool Debugger::LoadNewTensor(const std::shared_ptr<TensorData> &tensor, bool keep_prev) {
  return debug_services_->LoadNewTensor(tensor, keep_prev);
}

bool Debugger::debugger_enabled() const { return debugger_enabled_; }

DebuggerCommand GetCommand(const EventReply &reply) {
  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
  switch (reply.cmd_case()) {
    case debugger::EventReply::CmdCase::kExit:
      cmd = DebuggerCommand::kExitCMD;
      break;
    case debugger::EventReply::CmdCase::kRunCmd:
      cmd = DebuggerCommand::kRunCMD;
      break;
    case debugger::EventReply::CmdCase::kSetCmd:
      cmd = DebuggerCommand::kSetCMD;
      break;
    case debugger::EventReply::CmdCase::kViewCmd:
      cmd = DebuggerCommand::kViewCMD;
      break;
    case debugger::EventReply::CmdCase::kVersionMatched:
      cmd = DebuggerCommand::kVersionMatchedCMD;
      break;
    default:
      MS_LOG(DEBUG) << "Debug: UnknownCMD";
      break;
  }
  return cmd;
}

ProtoVector<WatchCondition_Parameter> GetParameters(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Can not get Parameters from command. Returning default value: ProtoVector<Parameter>().";
    return ProtoVector<WatchCondition_Parameter>();
  }
  return reply.set_cmd().watch_condition().params();
}

ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
    return ProtoVector<WatchNode>();
  }
  return reply.set_cmd().watch_nodes();
}

std::string GetRunLevel(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, can not get RunLevel. Returning default value: empty string.";
    return "";
  }
  return reply.run_cmd().run_level();
}

std::string GetNodeName(const EventReply &reply) {
  if (!reply.has_run_cmd()) {
    MS_LOG(ERROR) << "Error: Not RunCMD, can not get NodeName. Returning default value: empty string.";
    return "";
  }
  return reply.run_cmd().node_name();
}

WatchCondition GetWatchcondition(const EventReply &reply) {
  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
    return WatchCondition();
  }
  return reply.set_cmd().watch_condition();
}

int32_t GetWatchpointID(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
    return 0;
  }
  return reply.set_cmd().id();
}

bool GetWatchpointDelete(const EventReply &reply) {
  if (!reply.has_set_cmd()) {
    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
    return false;
  }
  return reply.set_cmd().delete_();
}

ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
  if (!reply.has_view_cmd()) {
    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
    return ProtoVector<TensorProto>();
  }
  return reply.view_cmd().tensors();
}

std::string GetTensorFullName(const TensorProto &tensor) {
  string node_name = tensor.node_name();
  if (tensor.truncate()) {
    // scopes in node name are separated by '/'
    // use the name without scope if truncate is true
    std::size_t found = node_name.find_last_of("/");
    node_name = node_name.substr(found + 1);
  }
  return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
}

bool GetMiVersionMatched(const EventReply &reply) { return reply.version_matched(); }

bool Debugger::partial_memory() const { return partial_memory_; }

void Debugger::SetEnableHeartbeat(bool enabled) { enable_heartbeat_ = enabled; }

void Debugger::SetCurNode(const std::string &cur_name) {
  // access lock for public method
  std::lock_guard<std::mutex> a_lock(access_lock_);
  cur_name_ = cur_name;
}

std::string Debugger::run_level() const { return run_level_; }

void Debugger::SetTrainingDone(bool training_done) { training_done_ = training_done; }

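// CheckPort accepts only a decimal string without leading zeros whose value is
// in [1, 65535]; CheckIp below accepts dotted-quad IPv4 addresses.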
bool Debugger::CheckPort(const std::string &port) const {
  int num = 0;
  const int min_port_num = 1;
  const int max_port_num = 65535;
  const int decimal = 10;
  if (port[0] == '0' && port[1] != '\0') return false;
  int i = 0;
  while (port[i] != '\0') {
    if (port[i] < '0' || port[i] > '9') return false;
    num = num * decimal + (port[i] - '0');
    if (num > max_port_num) return false;
    i++;
  }
  if (num < min_port_num) return false;
  return true;
}

bool Debugger::CheckIp(const std::string &host) const {
  std::regex reg_ip(
    "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
    "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
  std::smatch smat;
  std::string host_str = host;
  return std::regex_match(host_str, smat, reg_ip);
}

uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }

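// LoadSingleAnfnode copies one Parameter or ValueNode output from device memory
// to the host tensor cache. With MindRT on GPU, only ValueNodes and parameter
// weights can be loaded. Parameters keep their previous value so watchpoints can
// compare across iterations.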
void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<Parameter>() && !anf_node->isa<ValueNode>()) {
    return;
  }
  // When MindRT is used, only ValueNodes and ParameterWeights can be loaded from device to host
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT) && (device_target_ == kGPUDevice)) {
    if (!anf_node->isa<ValueNode>() &&
        !(anf_node->isa<Parameter>() && AnfAlgo::IsParameterWeight(anf_node->cast<ParameterPtr>()))) {
      return;
    }
  }
  // for parameters and value nodes, set their execution order to 0
  int exec_order = 0;
  std::string node_name = GetKernelNodeName(anf_node);
  GetFileKernelName(NOT_NULL(&node_name));
  // check whether the output address exists; if not, return
  if (!AnfAlgo::OutputAddrExist(anf_node, output_index)) {
    return;
  }
  auto addr = AnfAlgo::GetOutputAddr(anf_node, output_index);
  MS_EXCEPTION_IF_NULL(addr);
  auto type = AnfAlgo::GetOutputInferDataType(anf_node, output_index);
  if (!IsTypeDebuggerSupported(type)) {
    return;
  }
  auto format = kOpFormat_DEFAULT;
  string tensor_name = node_name + ':' + "0";
  ShapeVector int_shapes = trans::GetRuntimePaddingShape(anf_node, output_index);
  bool keep_prev;
  if (anf_node->isa<Parameter>()) {
    keep_prev = true;
    debug_services_->MoveTensorCurrentToPrev(tensor_name);
  } else {
    keep_prev = false;
  }
  bool ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, 0, keep_prev);
  if (!ret) {
    MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
  }
}

void Debugger::LoadParametersAndConst() {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters for graph " << graph_ptr_->graph_id() << ".";
  const auto &parameters = graph_ptr_->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes for graph " << graph_ptr_->graph_id() << ".";
  const auto value_nodes = graph_ptr_->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}

void Debugger::LoadParametersAndConst(const KernelGraphPtr &graph) {
  if (!(debugger_enabled_ || CheckDebuggerDumpEnabled())) return;
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  // load parameters
  MS_LOG(INFO) << "Start to load Parameters for graph " << graph->graph_id() << ".";
  const auto &parameters = graph_ptr_->inputs();
  for (auto &item : parameters) {
    LoadSingleAnfnode(item, PARAMETER_OUTPUT_INDEX);
  }
  // load value nodes
  // get all constant values from the graph
  MS_LOG(INFO) << "Start to load value nodes for graph " << graph->graph_id() << ".";
  const auto value_nodes = graph_ptr_->graph_value_nodes();
  for (auto &item : value_nodes) {
    LoadSingleAnfnode(item, VALUE_NODE_OUTPUT_INDEX);
  }
}

void Debugger::LoadGraphOutputs() {
  if (!(debugger_enabled() && device_target_ == kAscendDevice)) return;
  MS_EXCEPTION_IF_NULL(graph_ptr_);
  const auto &apply_kernels = graph_ptr_->execution_order();
  // for kernels, execution order starts from 1
  int exec_order = 1;
  for (const auto &node : apply_kernels) {
    MS_EXCEPTION_IF_NULL(node);
    std::string kernel_name = GetKernelNodeName(node);
    auto output_size = AnfAlgo::GetOutputTensorNum(node);
    if (partial_memory_) {
      if (!debug_services_->IsWatchPoint(kernel_name, node)) {
        continue;
      }
    }
    for (size_t j = 0; j < output_size; ++j) {
      if (!AnfAlgo::OutputAddrExist(node, j)) {
        MS_LOG(INFO) << "Cannot find output addr for slot " << j << " for " << kernel_name;
        continue;
      }
      auto addr = AnfAlgo::GetOutputAddr(node, j);
      MS_EXCEPTION_IF_NULL(addr);
      auto type = AnfAlgo::GetOutputInferDataType(node, j);
      if (!IsTypeDebuggerSupported(type)) {
        continue;
      }
      auto format = kOpFormat_DEFAULT;
      string tensor_name = kernel_name + ':' + std::to_string(j);
      ShapeVector int_shapes = trans::GetRuntimePaddingShape(node, j);
      auto ret = addr->LoadMemToHost(tensor_name, exec_order, format, int_shapes, type, j, false);
      if (!ret) {
        MS_LOG(ERROR) << "LoadMemToHost failed: tensor_name: " << tensor_name << ", host_format: " << format << ".";
      }
    }
    exec_order = exec_order + 1;
  }
}

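// The step counter is incremented once per step, on the first run graph only,
// so that the count stays correct in the multigraph case.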
void Debugger::UpdateStepNum(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(debugger_);
  // update step number if we are processing the first graph (to support multigraph)
  if (device_target_ == kGPUDevice && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration()) &&
      (graph->graph_id() == debugger_->GetFirstRunGraphId())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::UpdateStepNumGPU() {
  // called together with DebugActor::DebugOnStepEnd to update the step number
  if (device_target_ == kGPUDevice && (debugger_enabled_ || DumpDataEnabledIteration())) {
    // access lock for public method
    std::lock_guard<std::mutex> a_lock(access_lock_);
    ++num_step_;
  }
}

void Debugger::ClearCurrentData() {
  if ((device_target_ == kGPUDevice) && (debugger_enabled_ || device::KernelRuntime::DumpDataEnabledIteration())) {
    if (debug_services_) {
      debug_services_->EmptyCurrentTensor();
    } else {
      MS_LOG(ERROR) << "debug_services_ is nullptr";
    }
  }
}

bool Debugger::TensorExistsInCurrent(const std::string &tensor_name) {
  return debug_services_->TensorExistsInCurrent(tensor_name);
}

}  // namespace mindspore