/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "transform/graph_ir/storage_format_convertor.h"

#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
#include <memory>
#include <string>
#include "graph/types.h"
#include "ops/conv_pool_ops.h"
#include "transform/graph_ir/storage_format_config_factory.h"
#include "ir/func_graph.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/kernel_info.h"
#include "transform/graph_ir/transform_util.h"
#include "plugin/device/ascend/hal/common/ascend_utils.h"
#include "ops/framework_op_name.h"
#include "ops/framework_ops.h"
#include "ops/nn_optimizer_ops.h"

namespace mindspore::transform {
namespace {
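// Returns the first user of `node` whose CNode name matches `prim`, or nullptr if there is none.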
AnfNodePtr GetUsedOperator(const AnfNodePtr &node, const NodeUsersMap &node_users, const PrimitivePtr &prim) {
  auto iter = node_users.find(node);
  if (iter != node_users.end()) {
    for (const auto &node_user : iter->second) {
      if (common::AnfAlgo::GetCNodeName(node_user.first) == prim->name()) {
        return node_user.first;
      }
    }
  }
  return nullptr;
}

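// Checks whether `node` feeds a Conv2D through a Load, with an optional Cast in between.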
bool IsUsedByConv2D(const AnfNodePtr &node, const NodeUsersMap &node_users) {
  auto load = GetUsedOperator(node, node_users, prim::kPrimLoad);
  if (load == nullptr) {
    return false;
  }
  if (GetUsedOperator(load, node_users, prim::kPrimConv2D) != nullptr) {
    return true;
  }
  auto cast = GetUsedOperator(load, node_users, prim::kPrimCast);
  if (cast == nullptr) {
    return false;
  }
  return GetUsedOperator(cast, node_users, prim::kPrimConv2D) != nullptr;
}

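// Checks whether `node` is a Partial consumed by a Switch, or by a SwitchLayer via a MakeTuple.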
bool IsUsedBySwitch(const AnfNodePtr &node, const NodeUsersMap &node_users) {
  if (common::AnfAlgo::GetCNodeName(node) != prim::kPrimPartial->name()) {
    return false;
  }

  if (GetUsedOperator(node, node_users, prim::kPrimSwitch) != nullptr) {
    return true;
  }

  auto make_tuple = GetUsedOperator(node, node_users, prim::kPrimMakeTuple);
  if (make_tuple != nullptr && GetUsedOperator(make_tuple, node_users, prim::kPrimSwitchLayer) != nullptr) {
    return true;
  }

  return false;
}

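// Collects the real-kernel users of `node` by BFS over the node-user map, looking through virtual
// nodes such as Cast and TensorMove. Returns an empty vector when a Switch/SwitchLayer user is
// reached and the corresponding momentum variable is not consumed by Conv2D (see the note below).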
std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesSkipVirtualNode(const FuncGraphManagerPtr &manager,
                                                                      const AnfNodePtr &node) {
  std::vector<std::pair<AnfNodePtr, int>> res;
  std::queue<std::pair<AnfNodePtr, int>> anf_queue;
  std::vector<AnfNodePtr> visited;
  MS_EXCEPTION_IF_NULL(manager);
  auto node_users_map = manager->node_users();
  for (const auto &node_pair : node_users_map[node]) {
    anf_queue.push(node_pair);
    visited.push_back(node_pair.first);
  }
  // NOTE: the conversion from NC1HWC0 to ND is not supported between a parameter and a
  // Switch/SwitchLayer op, so bail out with an empty result when such a user is reached.
  auto momentum_var = GetMomentumVarByAccum(node, node_users_map);
  while (!anf_queue.empty()) {
    auto queue_front = anf_queue.front();
    anf_queue.pop();
    if (IsUsedBySwitch(queue_front.first, node_users_map) && (momentum_var != nullptr) &&
        !IsUsedByConv2D(momentum_var, node_users_map)) {
      return {};
    }
    std::string op_name = common::AnfAlgo::GetCNodeName(queue_front.first);
    if (AnfUtils::IsRealKernel(queue_front.first) && op_name != kCastOpName && op_name != kTensorMoveOpName) {
      res.push_back(queue_front);
      continue;
    }
    for (const auto &node_pair : node_users_map[queue_front.first]) {
      if (std::find(visited.begin(), visited.end(), node_pair.first) != visited.end()) {
        continue;
      }
      anf_queue.push(node_pair);
      visited.push_back(node_pair.first);
    }
  }
  return res;
}
}  // namespace

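// If `node` is the accumulator parameter of an ApplyMomentum op, returns the corresponding
// variable input of that op; otherwise returns nullptr.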
AnfNodePtr GetMomentumVarByAccum(const AnfNodePtr &node, const NodeUsersMap &node_users) {
  auto param = node->cast<ParameterPtr>();
  if (param == nullptr) {
    return nullptr;
  }

  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return nullptr;
  }

  for (const auto &param_user : iter->second) {
    auto cnode = param_user.first->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }

    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr || prim->name() != prim::kPrimApplyMomentum->name()) {
      continue;
    }

    // For ApplyMomentum, input(1) is the variable and input(2) is the accumulator.
    auto accum = cnode->input(2)->cast<ParameterPtr>();
    if (accum == nullptr) {
      continue;
    }

    if (accum->name() == param->name()) {
      return cnode->input(1);
    }
  }

  return nullptr;
}

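// Sets up the GE storage format for a parameter. The format is taken, in priority order, from the
// parameter's param_info, from the kernel build info selected by an earlier graph, or from the
// per-op storage format config. Returns false if the parameter's kernel info cannot be initialized.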
bool StorageFormatConvertor::SetupStorageFormat(const AnfGraphPtr &anf_graph, const AnfNodePtr &param,
                                                const std::shared_ptr<GeTensorDesc> &desc,
                                                const std::string &ori_format) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(desc);
  if (device::ascend::GetFormatMode() == "1" || !IsEnableRefMode()) {
    MS_LOG(DEBUG) << "Format mode is enabled or ref mode is disabled, no need to set storage format.";
    return true;
  }

  auto param_ptr = param->cast<ParameterPtr>();
  if (param_ptr != nullptr && param_ptr->param_info() != nullptr &&
      !param_ptr->param_info()->storage_format().empty()) {
    std::string store_fmt = param_ptr->param_info()->storage_format();
    MS_LOG(INFO) << "Update desc format from set format: graph: " << anf_graph->ToString()
                 << ", storage format: " << store_fmt << ", pre param: " << param->DebugString()
                 << ", full name: " << param->ToString();
    auto format = GetGeFormat(param, store_fmt, desc->GetOriginShape().GetDimNum());
    UpdateTensorDesc(desc, format);
    UpdateParameterKernelInfo(param, store_fmt);
    return true;
  }

  std::string set_format;
  if (!InitParameterKernelInfo(param, &set_format)) {
    MS_LOG(INFO) << "Failed to init parameter kernel info.";
    return false;
  }
  if (set_format.empty()) {
    // The weight's storage format is being set for the first time.
    SetStorageFormatFromConfig(anf_graph, param, desc);
  } else if (IsOneOfHWSpecialFormat(set_format)) {
    // The weight or data comes from another subgraph or a pynative node whose storage format has been set.
    MS_LOG(INFO) << "Update desc format from set format: graph: " << anf_graph->ToString()
                 << ", storage format: " << set_format << ", pre param: " << param->DebugString()
                 << ", full name: " << param->ToString();
    auto format = GetGeFormat(param, set_format, desc->GetOriginShape().GetDimNum());
    UpdateTensorDesc(desc, format);
  }
  return true;
}

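// Looks up each real user of `param` in the storage format config registry and, for every matching
// rule, rewrites the tensor desc and the parameter's kernel info with the configured storage format.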
void StorageFormatConvertor::SetStorageFormatFromConfig(const AnfGraphPtr &anf_graph, const AnfNodePtr &param,
                                                        const std::shared_ptr<GeTensorDesc> &desc) {
  MS_EXCEPTION_IF_NULL(anf_graph);
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(desc);
  auto manager = anf_graph->manager();
  if (!manager) {
    MS_LOG(WARNING) << "Anf graph: " << anf_graph->ToString() << "'s manager is null. Create a new one.";
    manager = Manage(anf_graph, true);
    anf_graph->set_manager(manager);
  }
  auto output_nodes = GetOutputNodesSkipVirtualNode(manager, param);
  for (const auto &user_node : output_nodes) {
    // Step 1: look up the storage format config for the user op type.
    auto op_type = common::AnfAlgo::GetCNodeName(user_node.first);
    auto storage_format_config_opt = StorageFormatConfigRegister::GetInstance().GetStorageFormatConfig(op_type);
    if (!storage_format_config_opt.has_value()) {
      continue;
    }
    auto &storage_format_config = storage_format_config_opt.value();
    // Step 2: match the input index the parameter feeds into.
    auto storage_format_info_opt = storage_format_config.GetStorageFormatInfo(IntToSize(user_node.second));
    if (!storage_format_info_opt.has_value()) {
      continue;
    }
    // Step 3: check the origin shape dims via the config's format function.
    auto &storage_format_info = storage_format_info_opt.value();
    auto fmt_opt = storage_format_info.func_(user_node.first, desc);
    if (!fmt_opt.has_value()) {
      continue;
    }
    // Step 4: update the desc and parameter format.
    MS_EXCEPTION_IF_NULL(user_node.first);
    std::string store_fmt = fmt_opt.value();
    auto format = GetGeFormat(param, user_node.first, store_fmt, desc->GetOriginShape().GetDimNum());
    MS_LOG(INFO) << "Update desc format from config, graph: " << anf_graph->ToString()
                 << ", used node: " << user_node.first->DebugString() << ", full name: " << user_node.first->ToString()
                 << ", input idx: " << user_node.second << ", storage format: " << store_fmt
                 << ", pre param: " << param->DebugString() << ", full name: " << param->ToString();
    UpdateTensorDesc(desc, format);
    if (!storage_format_info.expand_dims_.empty()) {
      MS_LOG(INFO) << "Set expand dims rule stub.";
      // desc->SetExpandDimsRule(storage_format_info.expand_dims_);
    }
    UpdateParameterKernelInfo(param, store_fmt);
  }
}

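// Applies the chosen storage format to the tensor desc: sets the format, clears the shape, and
// marks the placement as device memory.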
void StorageFormatConvertor::UpdateTensorDesc(const std::shared_ptr<GeTensorDesc> &desc, int32_t format) {
  MS_EXCEPTION_IF_NULL(desc);
  desc->SetFormat(static_cast<ge::Format>(format));
  desc->SetShape({});
  desc->SetPlacement(ge::kPlacementDevice);
}

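// Reads the output format already selected for `param` into `*format`, or initializes the build
// info with the default format and inferred data type when no output device type has been selected.
// Returns false when the parameter has no kernel info, no build info, or more than one output.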
bool StorageFormatConvertor::InitParameterKernelInfo(const AnfNodePtr &param, std::string *format) {
  // A parameter with a default value should have kernel info with exactly one output.
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(format);
  const auto &output_with_indexes = common::AnfAlgo::GetAllOutputWithIndex(param);
  if (output_with_indexes.size() != 1) {
    MS_LOG(ERROR) << "Param: " << param->ToString() << "'s output size is not 1.";
    return false;
  }
  std::shared_ptr<device::KernelInfo> kernel_info =
    std::dynamic_pointer_cast<device::KernelInfo>(param->kernel_info_ptr());
  if (!kernel_info) {
    // Creating a parameter node should also create its kernel info.
    MS_LOG(INFO) << "Param: " << param->ToString() << " does not have kernel info.";
    return false;
  }
  kernel::KernelBuildInfoPtr build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  if (build_info && build_info->GetOutputDeviceType(0) != kTypeUnknown) {
    (*format) = build_info->GetOutputFormat(0);
    MS_LOG(INFO) << "Param: " << param->ToString() << " node has been set up, build info: " << build_info->ToString();
    return true;
  }

  if (!build_info) {
    MS_LOG(ERROR) << "Param: " << param->ToString() << " build info is null.";
    return false;
  }

  std::vector<TypeId> output_infer_types;
  std::vector<std::string> output_formats;
  (void)output_infer_types.emplace_back(common::AnfAlgo::GetOutputInferDataType(param, 0));
  (void)output_formats.emplace_back(kOpFormat_DEFAULT);
  build_info->SetOutputsDeviceType(output_infer_types);
  build_info->SetOutputsFormat(output_formats);
  kernel_info->set_select_kernel_build_info(build_info);
  return true;
}

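// Overwrites the selected output format of `param`'s kernel build info with `format`.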
void StorageFormatConvertor::UpdateParameterKernelInfo(const AnfNodePtr &param, const std::string &format) {
  // A parameter with a default value should have kernel info with one output.
  MS_EXCEPTION_IF_NULL(param);
  std::shared_ptr<device::KernelInfo> kernel_info =
    std::dynamic_pointer_cast<device::KernelInfo>(param->kernel_info_ptr());
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel::KernelBuildInfoPtr build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  MS_EXCEPTION_IF_NULL(build_info);
  build_info->SetOutputsFormat({format});
  kernel_info->set_select_kernel_build_info(build_info);
}

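// Converts `storage_format` to a GE format value for `src_node` feeding `dst_node`. If the consumer
// carries a `groups` attribute (e.g. a grouped convolution), the group count is recorded on the
// parameter and encoded into the GE sub-format.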
int32_t StorageFormatConvertor::GetGeFormat(const AnfNodePtr &src_node, const AnfNodePtr &dst_node,
                                            const std::string &storage_format, size_t origin_dim) {
  MS_EXCEPTION_IF_NULL(src_node);
  MS_EXCEPTION_IF_NULL(dst_node);
  int64_t groups = 0;
  auto param = src_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  auto cnode = dst_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (common::AnfAlgo::HasNodeAttr(kAttrGroups, cnode)) {
    groups = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
    param->set_fracz_group(groups);
  }
  auto primary_format = TransformUtil::ConvertFormat(storage_format, origin_dim);
  auto format = ::ge::GetFormatFromSub(static_cast<int32_t>(primary_format), LongToInt(groups));
  return format;
}

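// Overload used when the consumer node is not available: reuses the fracz group count previously
// recorded on the parameter itself.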
int32_t StorageFormatConvertor::GetGeFormat(const AnfNodePtr &src_node, const std::string &storage_format,
                                            size_t origin_dim) {
  MS_EXCEPTION_IF_NULL(src_node);
  auto param = src_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param);
  auto primary_format = TransformUtil::ConvertFormat(storage_format, origin_dim);
  auto format = ::ge::GetFormatFromSub(static_cast<int32_t>(primary_format), LongToInt(param->fracz_group()));
  return format;
}
}  // namespace mindspore::transform