• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/special_node_postprocess.h"
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include "mindspore/core/ops/nn_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "include/errorcode.h"
26 #include "tools/optimizer/common/format_utils.h"
27 #include "nnacl/op_base.h"
28 #include "ops/op_utils.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace {
33 const PrimitivePtr kPrimInstanceNorm = std::make_shared<Primitive>("InstanceNorm");
GenerateNewShape(const abstract::AbstractBasePtr & abstract)34 ShapeVector GenerateNewShape(const abstract::AbstractBasePtr &abstract) {
35   MS_ASSERT(abstract != nullptr);
36   ShapeVector shape;
37   if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
38     return shape;
39   }
40   if (shape.size() == kInputSizeFour) {
41     ShapeVector real_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
42     shape = real_shape;
43   }
44   return shape;
45 }
46 }  // namespace
47 
Run(const FuncGraphPtr & func_graph)48 bool SpecialNodePostProcess::Run(const FuncGraphPtr &func_graph) {
49   MS_ASSERT(func_graph != nullptr);
50   auto manager = Manage(func_graph, true);
51   if (manager == nullptr) {
52     MS_LOG(ERROR) << "manager is nullptr.";
53     return false;
54   }
55   auto node_list = TopoSort(func_graph->get_return());
56   for (auto &node : node_list) {
57     if (!utils::isa<CNode>(node)) {
58       continue;
59     }
60     auto cnode = node->cast<CNodePtr>();
61     if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
62       auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
63       if (sub_func_graph == nullptr) {
64         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
65         return false;
66       }
67       if (!Run(sub_func_graph)) {
68         MS_LOG(ERROR) << "postprocess for handling special node failed.";
69         return false;
70       }
71       if (sub_func_graph == nullptr) {
72         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
73         return false;
74       }
75       if (!Run(sub_func_graph)) {
76         MS_LOG(ERROR) << "postprocess for handling special node failed.";
77         return false;
78       }
79       continue;
80     }
81     if (!CheckInstanceNorm(func_graph, cnode)) {
82       continue;
83     }
84     if (HandleInstanceNorm(func_graph, cnode) != lite::RET_OK) {
85       MS_LOG(ERROR) << "post-process instance_norm failed.";
86       return false;
87     }
88   }
89   return true;
90 }
91 
CheckInstanceNorm(const FuncGraphPtr & func_graph,const CNodePtr & cnode)92 bool SpecialNodePostProcess::CheckInstanceNorm(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
93   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
94   if (!CheckPrimitiveType(cnode, kPrimInstanceNorm)) {
95     return false;
96   }
97   auto manager = func_graph->manager();
98   MS_ASSERT(manager != nullptr);
99   auto pre_node = cnode->input(1);
100   if (!CheckPrimitiveType(pre_node, prim::kPrimConv2DFusion) && !CheckPrimitiveType(pre_node, prim::kPrimActivation)) {
101     return true;
102   }
103   if (!utils::isa<CNode>(pre_node)) {
104     return true;
105   }
106   std::vector<AnfNodePtr> pre_nodes;
107   pre_nodes.push_back(pre_node);
108   if (CheckPrimitiveType(pre_node, prim::kPrimActivation)) {
109     pre_node = pre_node->cast<CNodePtr>()->input(1);
110     if (!utils::isa<CNode>(pre_node) || !CheckPrimitiveType(pre_node, prim::kPrimConv2DFusion)) {
111       return true;
112     }
113     pre_nodes.push_back(pre_node);
114   }
115   bool is_nc = false;
116   for (const auto &node : pre_nodes) {
117     auto node_users = manager->node_users()[node];
118     is_nc = is_nc || std::any_of(node_users.begin(), node_users.end(), [](const std::pair<AnfNodePtr, int> &node_user) {
119               return !CheckPrimitiveType(node_user.first, kPrimInstanceNorm);
120             });
121   }
122   return is_nc;
123 }
124 
HandleInstanceNorm(const FuncGraphPtr & func_graph,const CNodePtr & cnode)125 int SpecialNodePostProcess::HandleInstanceNorm(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
126   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
127   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
128   MS_CHECK_TRUE_RET(prim != nullptr, lite::RET_ERROR);
129   if (prim->GetAttr(ops::kFormat) == nullptr) {
130     MS_LOG(ERROR) << "The node should have format attribute.";
131     return lite::RET_ERROR;
132   }
133   auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
134   if (format == mindspore::NCHW) {
135     return lite::RET_OK;
136   }
137   if (format != mindspore::NHWC) {
138     MS_LOG(ERROR) << "format attribute is invalid.";
139     return lite::RET_ERROR;
140   }
141   auto manager = func_graph->manager();
142   MS_ASSERT(manager != nullptr);
143   auto pre_transpose =
144     GenTransposeNode(func_graph, cnode->input(1), kNH2NC, cnode->fullname_with_scope() + "_pre_nh2nc");
145   MS_CHECK_TRUE_RET(pre_transpose != nullptr, lite::RET_ERROR);
146   auto pre_trans_prim = GetValueNode<PrimitivePtr>(pre_transpose->input(0));
147   MS_CHECK_TRUE_RET(pre_trans_prim != nullptr, lite::RET_ERROR);
148   (void)pre_trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
149   auto abstract = GetCNodeInputAbstract(cnode, 1);
150   if (abstract != nullptr) {
151     auto shape = GenerateNewShape(abstract);
152     auto pre_trans_abstract = abstract->Clone();
153     pre_trans_abstract->set_shape(std::make_shared<abstract::Shape>(shape));
154     pre_transpose->set_abstract(pre_trans_abstract);
155   }
156   manager->SetEdge(cnode, 1, pre_transpose);
157   auto post_transpose = GenTransposeNode(func_graph, cnode, kNC2NH, cnode->fullname_with_scope() + "_post_nc2nh");
158   MS_CHECK_TRUE_RET(post_transpose != nullptr, lite::RET_ERROR);
159   auto post_trans_prim = GetValueNode<PrimitivePtr>(post_transpose->input(0));
160   MS_CHECK_TRUE_RET(post_trans_prim != nullptr, lite::RET_ERROR);
161   (void)post_trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
162   (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
163   abstract = cnode->abstract();
164   if (abstract != nullptr) {
165     post_transpose->set_abstract(abstract->Clone());
166     auto shape = GenerateNewShape(abstract);
167     abstract->set_shape(std::make_shared<abstract::Shape>(shape));
168   }
169   (void)manager->Replace(cnode, post_transpose);
170   return lite::RET_OK;
171 }
172 }  // namespace opt
173 }  // namespace mindspore
174