• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/storage_format_convertor.h"
18 
19 #include <queue>
20 #include <utility>
21 #include <vector>
22 #include <memory>
23 #include <string>
24 #include "graph/types.h"
25 #include "ops/conv_pool_ops.h"
26 #include "transform/graph_ir/storage_format_config_factory.h"
27 #include "ir/func_graph.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "include/backend/kernel_info.h"
30 #include "transform/graph_ir/transform_util.h"
31 #include "plugin/device/ascend/hal/common/ascend_utils.h"
32 #include "ops/framework_op_name.h"
33 #include "ops/framework_ops.h"
34 #include "ops/nn_optimizer_ops.h"
35 
36 namespace mindspore::transform {
37 namespace {
// Returns the first user of `node` whose CNode name equals `prim`'s name, or
// nullptr when `node` has no such user in the given user map.
AnfNodePtr GetUsedOperator(const AnfNodePtr &node, const NodeUsersMap &node_users, const PrimitivePtr &prim) {
  const auto users_iter = node_users.find(node);
  if (users_iter == node_users.end()) {
    return nullptr;
  }
  const std::string &target_name = prim->name();
  for (const auto &user : users_iter->second) {
    if (common::AnfAlgo::GetCNodeName(user.first) == target_name) {
      return user.first;
    }
  }
  return nullptr;
}
49 
IsUsedByConv2D(const AnfNodePtr & node,const NodeUsersMap & node_users)50 bool IsUsedByConv2D(const AnfNodePtr &node, const NodeUsersMap &node_users) {
51   auto load = GetUsedOperator(node, node_users, prim::kPrimLoad);
52   if (load == nullptr) {
53     return false;
54   }
55   if (GetUsedOperator(load, node_users, prim::kPrimConv2D) != nullptr) {
56     return true;
57   }
58   auto cast = GetUsedOperator(load, node_users, prim::kPrimCast);
59   if (cast == nullptr) {
60     return false;
61   }
62   return GetUsedOperator(cast, node_users, prim::kPrimConv2D) != nullptr;
63 }
64 
IsUsedBySwitch(const AnfNodePtr & node,const NodeUsersMap & node_users)65 bool IsUsedBySwitch(const AnfNodePtr &node, const NodeUsersMap &node_users) {
66   if (common::AnfAlgo::GetCNodeName(node) != prim::kPrimPartial->name()) {
67     return false;
68   }
69 
70   if (GetUsedOperator(node, node_users, prim::kPrimSwitch) != nullptr) {
71     return true;
72   }
73 
74   auto make_tuple = GetUsedOperator(node, node_users, prim::kPrimMakeTuple);
75   if (make_tuple != nullptr && GetUsedOperator(make_tuple, node_users, prim::kPrimSwitchLayer) != nullptr) {
76     return true;
77   }
78 
79   return false;
80 }
81 
GetOutputNodesSkipVirtualNode(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)82 std::vector<std::pair<AnfNodePtr, int>> GetOutputNodesSkipVirtualNode(const FuncGraphManagerPtr &manager,
83                                                                       const AnfNodePtr &node) {
84   std::vector<std::pair<AnfNodePtr, int>> res;
85   std::queue<std::pair<AnfNodePtr, int>> anf_queue;
86   std::vector<AnfNodePtr> visited;
87   MS_EXCEPTION_IF_NULL(manager);
88   auto node_users_map = manager->node_users();
89   for (const auto &node_pair : node_users_map[node]) {
90     anf_queue.push(node_pair);
91     visited.push_back(node_pair.first);
92   }
93   while (!anf_queue.empty()) {
94     auto queue_front = anf_queue.front();
95     anf_queue.pop();
96     // NOTE fix: do not support trans from NC1HWC0 to ND between parameter and Switch-op/switch_layer-op
97     auto momentum_var = GetMomentumVarByAccum(node, node_users_map);
98     if (IsUsedBySwitch(queue_front.first, node_users_map) && (momentum_var != nullptr) &&
99         !IsUsedByConv2D(momentum_var, node_users_map)) {
100       return {};
101     }
102     std::string op_name = common::AnfAlgo::GetCNodeName(queue_front.first);
103     if (AnfUtils::IsRealKernel(queue_front.first) && op_name != kCastOpName && op_name != kTensorMoveOpName) {
104       res.push_back(queue_front);
105       continue;
106     }
107     for (const auto &node_pair : node_users_map[queue_front.first]) {
108       if (std::find(visited.begin(), visited.end(), node_pair.first) != visited.end()) {
109         continue;
110       }
111       anf_queue.push(node_pair);
112       visited.push_back(node_pair.first);
113     }
114   }
115   return res;
116 }
117 }  // namespace
118 
// Given a parameter `node`, scans its users for an ApplyMomentum whose accum
// input names the same parameter, and returns that ApplyMomentum's var input.
// Returns nullptr when `node` is not a Parameter or no such user exists.
AnfNodePtr GetMomentumVarByAccum(const AnfNodePtr &node, const NodeUsersMap &node_users) {
  // ApplyMomentum input layout: input(1) = var, input(2) = accum.
  constexpr size_t kVarInputIdx = 1;
  constexpr size_t kAccumInputIdx = 2;
  const auto param = node->cast<ParameterPtr>();
  if (param == nullptr) {
    return nullptr;
  }
  const auto users_iter = node_users.find(node);
  if (users_iter == node_users.end()) {
    return nullptr;
  }
  for (const auto &user : users_iter->second) {
    const auto cnode = user.first->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    const auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr || prim->name() != prim::kPrimApplyMomentum->name()) {
      continue;
    }
    const auto accum_param = cnode->input(kAccumInputIdx)->cast<ParameterPtr>();
    if (accum_param != nullptr && accum_param->name() == param->name()) {
      return cnode->input(kVarInputIdx);
    }
  }
  return nullptr;
}
153 
SetupStorageFormat(const AnfGraphPtr & anf_graph,const AnfNodePtr & param,const std::shared_ptr<GeTensorDesc> & desc,const std::string & ori_format)154 bool StorageFormatConvertor::SetupStorageFormat(const AnfGraphPtr &anf_graph, const AnfNodePtr &param,
155                                                 const std::shared_ptr<GeTensorDesc> &desc,
156                                                 const std::string &ori_format) {
157   MS_EXCEPTION_IF_NULL(anf_graph);
158   MS_EXCEPTION_IF_NULL(param);
159   MS_EXCEPTION_IF_NULL(desc);
160   if (device::ascend::GetFormatMode() == "1" || !IsEnableRefMode()) {
161     MS_LOG(DEBUG) << "Enable format mode or disable ref mode, no need to set storage format";
162     return true;
163   }
164 
165   auto param_ptr = param->cast<ParameterPtr>();
166   if (param_ptr != nullptr && param_ptr->param_info() != nullptr &&
167       !param_ptr->param_info()->storage_format().empty()) {
168     std::string store_fmt = param_ptr->param_info()->storage_format();
169     MS_LOG(INFO) << "Update desc format from set format: graph: " << anf_graph->ToString()
170                  << ", storage format: " << store_fmt << ", pre param: " << param->DebugString()
171                  << ", full name: " << param->ToString();
172     auto format = GetGeFormat(param, store_fmt, desc->GetOriginShape().GetDimNum());
173     UpdateTensorDesc(desc, format);
174     UpdateParameterKernelInfo(param, store_fmt);
175     return true;
176   }
177 
178   std::string set_format;
179   if (!InitParameterKernelInfo(param, &set_format)) {
180     MS_LOG(INFO) << "Please attention: init Param kernel info failed.";
181     return false;
182   }
183   if (set_format.empty()) {
184     // The weight change storage format first time.
185     SetStorageFormatFromConfig(anf_graph, param, desc);
186   } else if (IsOneOfHWSpecialFormat(set_format)) {
187     // The weight or data is from other subgraph or pynative node which has been set storage format.
188     MS_LOG(INFO) << "Update desc format from set format: graph: " << anf_graph->ToString()
189                  << ", storage format: " << set_format << ", pre param: " << param->DebugString()
190                  << ", full name: " << param->ToString();
191     auto format = GetGeFormat(param, set_format, desc->GetOriginShape().GetDimNum());
192     UpdateTensorDesc(desc, format);
193   }
194   return true;
195 }
196 
SetStorageFormatFromConfig(const AnfGraphPtr & anf_graph,const AnfNodePtr & param,const std::shared_ptr<GeTensorDesc> & desc)197 void StorageFormatConvertor::SetStorageFormatFromConfig(const AnfGraphPtr &anf_graph, const AnfNodePtr &param,
198                                                         const std::shared_ptr<GeTensorDesc> &desc) {
199   MS_EXCEPTION_IF_NULL(anf_graph);
200   MS_EXCEPTION_IF_NULL(param);
201   MS_EXCEPTION_IF_NULL(desc);
202   auto manager = anf_graph->manager();
203   if (!manager) {
204     MS_LOG(WARNING) << "Anf graph: " << anf_graph->ToString() << "'s manager is null. create a new one.";
205     manager = Manage(anf_graph, true);
206     anf_graph->set_manager(manager);
207   }
208   auto output_nodes = GetOutputNodesSkipVirtualNode(manager, param);
209   for (const auto &user_node : output_nodes) {
210     // Step 1: node storage format config
211     auto op_type = common::AnfAlgo::GetCNodeName(user_node.first);
212     auto storage_format_config_opt = StorageFormatConfigRegister::GetInstance().GetStorageFormatConfig(op_type);
213     if (!storage_format_config_opt.has_value()) {
214       continue;
215     }
216     auto &storage_format_config = storage_format_config_opt.value();
217     // Step 2: node user index match
218     auto storage_format_info_opt = storage_format_config.GetStorageFormatInfo(IntToSize(user_node.second));
219     if (!storage_format_info_opt.has_value()) {
220       continue;
221     }
222     // Step 3: check origin shape dims
223     auto &storage_format_info = storage_format_info_opt.value();
224     auto fmt_opt = storage_format_info.func_(user_node.first, desc);
225     if (!fmt_opt.has_value()) {
226       continue;
227     }
228     // Step 4: update desc and param format
229     MS_EXCEPTION_IF_NULL(user_node.first);
230     std::string store_fmt = fmt_opt.value();
231     auto format = GetGeFormat(param, user_node.first, store_fmt, desc->GetOriginShape().GetDimNum());
232     MS_LOG(INFO) << "Update desc format from config, graph: " << anf_graph->ToString()
233                  << ", used node: " << user_node.first->DebugString() << ", full name: " << user_node.first->ToString()
234                  << ",input idx: " << user_node.second << ", storage format: " << store_fmt
235                  << ", pre param: " << param->DebugString() << ", full name: " << param->ToString();
236     UpdateTensorDesc(desc, format);
237     if (!storage_format_info.expand_dims_.empty()) {
238       MS_LOG(INFO) << "Set expand dims rule stub.";
239       // desc->SetExpandDimsRule(storage_format_info.expand_dims_);
240     }
241     UpdateParameterKernelInfo(param, store_fmt);
242   }
243 }
244 
// Applies the resolved GE storage format to the tensor descriptor.
// The shape is reset to empty and the placement forced to device; the
// callers rely on GE to rebuild the device shape from the new format.
void StorageFormatConvertor::UpdateTensorDesc(const std::shared_ptr<GeTensorDesc> &desc, int32_t format) {
  MS_EXCEPTION_IF_NULL(desc);
  desc->SetFormat(static_cast<ge::Format>(format));
  desc->SetShape({});
  desc->SetPlacement(ge::kPlacementDevice);
}
251 
InitParameterKernelInfo(const AnfNodePtr & param,std::string * format)252 bool StorageFormatConvertor::InitParameterKernelInfo(const AnfNodePtr &param, std::string *format) {
253   // param has default should have kernel info with one output
254   MS_EXCEPTION_IF_NULL(param);
255   MS_EXCEPTION_IF_NULL(format);
256   const auto &output_with_indexes = common::AnfAlgo::GetAllOutputWithIndex(param);
257   if (output_with_indexes.size() != 1) {
258     MS_LOG(ERROR) << "Param: " << param->ToString() << "'s output size is not 1.";
259     return false;
260   }
261   std::shared_ptr<device::KernelInfo> kernel_info =
262     std::dynamic_pointer_cast<device::KernelInfo>(param->kernel_info_ptr());
263   if (!kernel_info) {
264     // create parameter node should create kernel info
265     MS_LOG(INFO) << "Please attention, param: " << param->ToString() << "don't have kernel info.";
266     return false;
267   }
268   kernel::KernelBuildInfoPtr build_info = kernel_info->GetMutableSelectKernelBuildInfo();
269   if (build_info && build_info->GetOutputDeviceType(0) != kTypeUnknown) {
270     (*format) = build_info->GetOutputFormat(0);
271     MS_LOG(INFO) << "Param: " << param->ToString() << " node has been setup, build info: " << build_info->ToString();
272     return true;
273   }
274 
275   if (!build_info) {
276     MS_LOG(ERROR) << "Param: " << param->ToString() << " build info is null.";
277     return false;
278   }
279 
280   std::vector<TypeId> output_infer_types;
281   std::vector<std::string> output_formats;
282   (void)output_infer_types.emplace_back(common::AnfAlgo::GetOutputInferDataType(param, 0));
283   (void)output_formats.emplace_back(kOpFormat_DEFAULT);
284   build_info->SetOutputsDeviceType(output_infer_types);
285   build_info->SetOutputsFormat(output_formats);
286   kernel_info->set_select_kernel_build_info(build_info);
287   return true;
288 }
289 
// Records `format` as the (single) output format in the parameter's selected
// kernel build info. Throws via MS_EXCEPTION_IF_NULL when the parameter has
// no kernel info or build info — callers are expected to have initialized it
// (see InitParameterKernelInfo).
void StorageFormatConvertor::UpdateParameterKernelInfo(const AnfNodePtr &param, const std::string &format) {
  // param has default should have kernel info with one output
  MS_EXCEPTION_IF_NULL(param);
  std::shared_ptr<device::KernelInfo> kernel_info =
    std::dynamic_pointer_cast<device::KernelInfo>(param->kernel_info_ptr());
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel::KernelBuildInfoPtr build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  MS_EXCEPTION_IF_NULL(build_info);
  build_info->SetOutputsFormat({format});
  kernel_info->set_select_kernel_build_info(build_info);
}
301 
GetGeFormat(const AnfNodePtr & src_node,const AnfNodePtr & dst_node,const std::string & storage_format,size_t origin_dim)302 int32_t StorageFormatConvertor::GetGeFormat(const AnfNodePtr &src_node, const AnfNodePtr &dst_node,
303                                             const std::string &storage_format, size_t origin_dim) {
304   MS_EXCEPTION_IF_NULL(src_node);
305   MS_EXCEPTION_IF_NULL(dst_node);
306   int64_t groups = 0;
307   auto param = src_node->cast<ParameterPtr>();
308   MS_EXCEPTION_IF_NULL(param);
309   auto cnode = dst_node->cast<CNodePtr>();
310   MS_EXCEPTION_IF_NULL(cnode);
311   if (common::AnfAlgo::HasNodeAttr(kAttrGroups, cnode)) {
312     groups = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrGroups);
313     param->set_fracz_group(groups);
314   }
315   auto primary_format = TransformUtil::ConvertFormat(storage_format, origin_dim);
316   auto format = ::ge::GetFormatFromSub(static_cast<int32_t>(primary_format), LongToInt(groups));
317   return format;
318 }
319 
GetGeFormat(const AnfNodePtr & src_node,const std::string & storage_format,size_t origin_dim)320 int32_t StorageFormatConvertor::GetGeFormat(const AnfNodePtr &src_node, const std::string &storage_format,
321                                             size_t origin_dim) {
322   MS_EXCEPTION_IF_NULL(src_node);
323   auto param = src_node->cast<ParameterPtr>();
324   MS_EXCEPTION_IF_NULL(param);
325   auto primary_format = TransformUtil::ConvertFormat(storage_format, origin_dim);
326   auto format = ::ge::GetFormatFromSub(static_cast<int32_t>(primary_format), LongToInt(param->fracz_group()));
327   return format;
328 }
329 }  // namespace mindspore::transform
330