1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
#include "tools/optimizer/graph/special_node_postprocess.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/errorcode.h"
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
29
30 namespace mindspore {
31 namespace opt {
32 namespace {
33 const PrimitivePtr kPrimInstanceNorm = std::make_shared<Primitive>("InstanceNorm");
GenerateNewShape(const abstract::AbstractBasePtr & abstract)34 ShapeVector GenerateNewShape(const abstract::AbstractBasePtr &abstract) {
35 MS_ASSERT(abstract != nullptr);
36 ShapeVector shape;
37 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
38 return shape;
39 }
40 if (shape.size() == kInputSizeFour) {
41 ShapeVector real_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
42 shape = real_shape;
43 }
44 return shape;
45 }
46 } // namespace
47
Run(const FuncGraphPtr & func_graph)48 bool SpecialNodePostProcess::Run(const FuncGraphPtr &func_graph) {
49 MS_ASSERT(func_graph != nullptr);
50 auto manager = Manage(func_graph, true);
51 if (manager == nullptr) {
52 MS_LOG(ERROR) << "manager is nullptr.";
53 return false;
54 }
55 auto node_list = TopoSort(func_graph->get_return());
56 for (auto &node : node_list) {
57 if (!utils::isa<CNode>(node)) {
58 continue;
59 }
60 auto cnode = node->cast<CNodePtr>();
61 if (CheckPrimitiveType(cnode, prim::kPrimIf) || CheckPrimitiveType(cnode, prim::kPrimWhile)) {
62 auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
63 if (sub_func_graph == nullptr) {
64 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
65 return false;
66 }
67 if (!Run(sub_func_graph)) {
68 MS_LOG(ERROR) << "postprocess for handling special node failed.";
69 return false;
70 }
71 if (sub_func_graph == nullptr) {
72 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
73 return false;
74 }
75 if (!Run(sub_func_graph)) {
76 MS_LOG(ERROR) << "postprocess for handling special node failed.";
77 return false;
78 }
79 continue;
80 }
81 if (!CheckInstanceNorm(func_graph, cnode)) {
82 continue;
83 }
84 if (HandleInstanceNorm(func_graph, cnode) != lite::RET_OK) {
85 MS_LOG(ERROR) << "post-process instance_norm failed.";
86 return false;
87 }
88 }
89 return true;
90 }
91
CheckInstanceNorm(const FuncGraphPtr & func_graph,const CNodePtr & cnode)92 bool SpecialNodePostProcess::CheckInstanceNorm(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
93 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
94 if (!CheckPrimitiveType(cnode, kPrimInstanceNorm)) {
95 return false;
96 }
97 auto manager = func_graph->manager();
98 MS_ASSERT(manager != nullptr);
99 auto pre_node = cnode->input(1);
100 if (!CheckPrimitiveType(pre_node, prim::kPrimConv2DFusion) && !CheckPrimitiveType(pre_node, prim::kPrimActivation)) {
101 return true;
102 }
103 if (!utils::isa<CNode>(pre_node)) {
104 return true;
105 }
106 std::vector<AnfNodePtr> pre_nodes;
107 pre_nodes.push_back(pre_node);
108 if (CheckPrimitiveType(pre_node, prim::kPrimActivation)) {
109 pre_node = pre_node->cast<CNodePtr>()->input(1);
110 if (!utils::isa<CNode>(pre_node) || !CheckPrimitiveType(pre_node, prim::kPrimConv2DFusion)) {
111 return true;
112 }
113 pre_nodes.push_back(pre_node);
114 }
115 bool is_nc = false;
116 for (const auto &node : pre_nodes) {
117 auto node_users = manager->node_users()[node];
118 is_nc = is_nc || std::any_of(node_users.begin(), node_users.end(), [](const std::pair<AnfNodePtr, int> &node_user) {
119 return !CheckPrimitiveType(node_user.first, kPrimInstanceNorm);
120 });
121 }
122 return is_nc;
123 }
124
HandleInstanceNorm(const FuncGraphPtr & func_graph,const CNodePtr & cnode)125 int SpecialNodePostProcess::HandleInstanceNorm(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
126 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
127 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
128 MS_CHECK_TRUE_RET(prim != nullptr, lite::RET_ERROR);
129 if (prim->GetAttr(ops::kFormat) == nullptr) {
130 MS_LOG(ERROR) << "The node should have format attribute.";
131 return lite::RET_ERROR;
132 }
133 auto format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
134 if (format == mindspore::NCHW) {
135 return lite::RET_OK;
136 }
137 if (format != mindspore::NHWC) {
138 MS_LOG(ERROR) << "format attribute is invalid.";
139 return lite::RET_ERROR;
140 }
141 auto manager = func_graph->manager();
142 MS_ASSERT(manager != nullptr);
143 auto pre_transpose =
144 GenTransposeNode(func_graph, cnode->input(1), kNH2NC, cnode->fullname_with_scope() + "_pre_nh2nc");
145 MS_CHECK_TRUE_RET(pre_transpose != nullptr, lite::RET_ERROR);
146 auto pre_trans_prim = GetValueNode<PrimitivePtr>(pre_transpose->input(0));
147 MS_CHECK_TRUE_RET(pre_trans_prim != nullptr, lite::RET_ERROR);
148 (void)pre_trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NHWC));
149 auto abstract = GetCNodeInputAbstract(cnode, 1);
150 if (abstract != nullptr) {
151 auto shape = GenerateNewShape(abstract);
152 auto pre_trans_abstract = abstract->Clone();
153 pre_trans_abstract->set_shape(std::make_shared<abstract::Shape>(shape));
154 pre_transpose->set_abstract(pre_trans_abstract);
155 }
156 manager->SetEdge(cnode, 1, pre_transpose);
157 auto post_transpose = GenTransposeNode(func_graph, cnode, kNC2NH, cnode->fullname_with_scope() + "_post_nc2nh");
158 MS_CHECK_TRUE_RET(post_transpose != nullptr, lite::RET_ERROR);
159 auto post_trans_prim = GetValueNode<PrimitivePtr>(post_transpose->input(0));
160 MS_CHECK_TRUE_RET(post_trans_prim != nullptr, lite::RET_ERROR);
161 (void)post_trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
162 (void)prim->AddAttr(ops::kFormat, MakeValue<int64_t>(mindspore::NCHW));
163 abstract = cnode->abstract();
164 if (abstract != nullptr) {
165 post_transpose->set_abstract(abstract->Clone());
166 auto shape = GenerateNewShape(abstract);
167 abstract->set_shape(std::make_shared<abstract::Shape>(shape));
168 }
169 (void)manager->Replace(cnode, post_transpose);
170 return lite::RET_OK;
171 }
172 } // namespace opt
173 } // namespace mindspore
174