/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "profiler/device/ascend/memory_profiling.h"
#include <fstream>
#include <memory>
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "nlohmann/json.hpp"
#include "profiler/device/ascend/ascend_profiling.h"

namespace mindspore {
namespace profiler {
namespace ascend {
constexpr char kOutputPath[] = "output";

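// Returns true only when profiling is enabled globally and the "profile_memory"
// option in the profiling options JSON is not explicitly set to "off".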
bool MemoryProfiling::IsMemoryProfilingEnable() const {
  auto ascend_profiler = AscendProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(ascend_profiler);
  if (!ascend_profiler->GetProfilingEnableFlag()) {
    return false;
  }

  const std::string prof_options_str = ascend_profiler->GetProfilingOptions();
  nlohmann::json options;
  try {
    options = nlohmann::json::parse(prof_options_str);
  } catch (const nlohmann::json::exception &e) {
    MS_LOG(ERROR) << "Failed to parse profiling options: " << e.what();
    return false;
  }

  // Memory profiling is treated as enabled unless it is explicitly turned off.
  if (options["profile_memory"] == "off") {
    return false;
  }

  return true;
}

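// Creates a GraphMemory record for the given graph id and registers it in
// graph_memory_, overwriting any existing entry for that id.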
std::shared_ptr<GraphMemory> MemoryProfiling::AddGraphMemoryNode(uint32_t graph_id) {
  std::shared_ptr<GraphMemory> node = std::make_shared<GraphMemory>(graph_id);
  MS_EXCEPTION_IF_NULL(node);
  graph_memory_[graph_id] = node;
  return node;
}

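// Looks up the GraphMemory record for the given graph id; returns nullptr if it is absent.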
std::shared_ptr<GraphMemory> MemoryProfiling::GetGraphMemoryNode(uint32_t graph_id) const {
  auto node = graph_memory_.find(graph_id);
  if (node != graph_memory_.end()) {
    return node->second;
  }

  return nullptr;
}

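// Serializes the collected memory usage into memory_proto_: total device memory,
// per-graph static memory, per-node tensor id lists, and per-tensor lifetimes.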
bool MemoryProfiling::MemoryToPB() {
  memory_proto_.set_total_mem(device_mem_size_);
  for (const auto &graph : graph_memory_) {
    GraphMemProto *graph_proto = memory_proto_.add_graph_mem();
    if (graph_proto == nullptr) {
      MS_LOG(ERROR) << "Add graph memory proto failed.";
      return false;
    }
    graph_proto->set_graph_id(graph.second->GetGraphId());
    graph_proto->set_static_mem(graph.second->GetStaticMemSize());
    // node memory to PB
    for (const auto &node : graph.second->GetNodeMemory()) {
      NodeMemProto *node_mem = graph_proto->add_node_mems();
      if (node_mem == nullptr) {
        MS_LOG(ERROR) << "Add node memory proto failed.";
        return false;
      }
      node_mem->set_node_name(node.GetNodeName());
      node_mem->set_node_id(node.GetNodeId());
      for (const auto &id : node.GetInputTensorId()) {
        node_mem->add_input_tensor_id(id);
      }
      for (const auto &id : node.GetOutputTensorId()) {
        node_mem->add_output_tensor_id(id);
      }
      // The original loop iterated GetOutputTensorId() again here, which looks like a
      // copy-paste error; workspace ids are assumed to come from GetWorkspaceTensorId().
      for (const auto &id : node.GetWorkspaceTensorId()) {
        node_mem->add_workspace_tensor_id(id);
      }
    }
    // tensor memory to PB
    for (const auto &node : graph.second->GetTensorMemory()) {
      TensorMemProto *tensor_mem = graph_proto->add_tensor_mems();
      if (tensor_mem == nullptr) {
        MS_LOG(ERROR) << "Add tensor memory proto failed.";
        return false;
      }
      tensor_mem->set_tensor_id(node.GetTensorId());
      tensor_mem->set_size(node.GetAlignedSize());
      std::string type = node.GetType();
      tensor_mem->set_type(type);
      tensor_mem->set_life_start(node.GetLifeStart());
      tensor_mem->set_life_end(node.GetLifeEnd());
      std::string life_long = node.GetLifeLong();
      tensor_mem->set_life_long(life_long);
    }
  }
  MS_LOG(INFO) << "Memory profiling data has been converted to protobuf.";
  return true;
}

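// Reads the "output" directory from the profiling options JSON and resolves it
// to an absolute path; returns an empty string on any failure.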
std::string MemoryProfiling::GetOutputPath() const {
  auto ascend_profiler = AscendProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(ascend_profiler);
  const std::string options_str = ascend_profiler->GetProfilingOptions();
  nlohmann::json options_json;
  try {
    options_json = nlohmann::json::parse(options_str);
  } catch (const nlohmann::json::parse_error &e) {
    MS_LOG(EXCEPTION) << "Parse profiling option json failed, error: " << e.what();
  }
  auto iter = options_json.find(kOutputPath);
  if (iter != options_json.end() && iter->is_string()) {
    // json::size() on a string value is always 1, so validate the length of the
    // extracted string itself before resolving it.
    std::string output_path = iter->get<std::string>();
    char real_path[PATH_MAX] = {0};
    if (output_path.size() >= PATH_MAX) {
      MS_LOG(ERROR) << "Path is invalid for memory profiling.";
      return "";
    }
#if defined(_WIN32) || defined(_WIN64)
    if (_fullpath(real_path, common::SafeCStr(output_path), PATH_MAX) == nullptr) {
      MS_LOG(ERROR) << "Path is invalid for memory profiling.";
      return "";
    }
#else
    if (realpath(common::SafeCStr(output_path), real_path) == nullptr) {
      MS_LOG(ERROR) << "Path is invalid for memory profiling.";
      return "";
    }
#endif
    return real_path;
  }

  MS_LOG(ERROR) << "Output path is not found when saving memory profiling data.";
  return "";
}

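// Converts the collected data via MemoryToPB and writes it to
// <output_path>/memory_usage_<RANK_ID>.pb (RANK_ID defaults to "0").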
void MemoryProfiling::SaveMemoryProfiling() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  std::string dir_path = GetOutputPath();
  if (dir_path.empty()) {
    MS_LOG(ERROR) << "Get output path failed, skip saving memory profiling data.";
    return;
  }
  auto device_id = common::GetEnv("RANK_ID");
  // If RANK_ID is not set, the default value is 0.
  if (device_id.empty()) {
    device_id = "0";
  }
  std::string file = dir_path + std::string("/memory_usage_") + std::string(device_id) + std::string(".pb");

  if (!MemoryToPB()) {
    MS_LOG(ERROR) << "Convert memory profiling data to protobuf failed.";
    return;
  }

  std::fstream handle(file, std::ios::out | std::ios::trunc | std::ios::binary);
  if (!memory_proto_.SerializeToOstream(&handle)) {
    MS_LOG(ERROR) << "Save memory profiling data to file failed.";
  }
  handle.close();
  MS_LOG(INFO) << "Save memory profiling data to " << file << " finished.";
}
}  // namespace ascend
}  // namespace profiler
}  // namespace mindspore