1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/validator.h"
20
21 #include <memory>
22 #include <mutex>
23 #include <string>
24
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/other_ops.h"
28 #include "mindspore/core/ops/nn_optimizer_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "ir/manager.h"
31 #include "ir/dtype.h"
32 #include "pipeline/jit/ps/static_analysis/prim.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/debug/trace.h"
35
36 namespace mindspore {
37 namespace validator {
38 using mindspore::abstract::AbstractBase;
39 using mindspore::abstract::AbstractFunction;
40 using mindspore::abstract::AbstractJTagged;
41 using mindspore::abstract::AbstractList;
42 using mindspore::abstract::AbstractMapTensor;
43 using mindspore::abstract::AbstractProblem;
44 using mindspore::abstract::AbstractRefTensor;
45 using mindspore::abstract::AbstractRowTensor;
46 using mindspore::abstract::AbstractScalar;
47 using mindspore::abstract::AbstractSequence;
48 using mindspore::abstract::AbstractTensor;
49 using mindspore::abstract::AbstractTuple;
50 using mindspore::abstract::AbstractType;
51
ValidateOperation(const AnfNodePtr & node)52 void ValidateOperation(const AnfNodePtr &node) {
53 if (!IsValueNode<Primitive>(node)) {
54 return;
55 }
56
57 // Primitive must in whitelist
58 auto prim = GetValueNode<PrimitivePtr>(node);
59 MS_EXCEPTION_IF_NULL(prim);
60 if (prim->isa<prim::DoSignaturePrimitive>()) {
61 MS_LOG(INTERNAL_EXCEPTION) << "Illegal DoSignaturePrimitive '" << prim->name() << "' in the graph."
62 << "node:" << node->DebugString()
63 << ", location:" << trace::GetDebugInfoStr(node->debug_info());
64 }
65 if (abstract::IsInWhiteList(prim)) {
66 return;
67 }
68 if (prim->HasAttr("is_load")) {
69 return;
70 }
71 if (prim->name() == "PyExecute") {
72 return;
73 }
74 if (prim->name() == "TensorMove") {
75 return;
76 }
77
78 if (prim->isa<PrimitivePy>()) {
79 MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
80 return;
81 }
82 if (prim->name() == "fake_bprop") {
83 MS_LOG(INTERNAL_EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"))
84 << "node:" << node->DebugString()
85 << ", location:" << trace::GetDebugInfoStr(node->debug_info());
86 }
87
88 MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name()
89 << ". Please check whether to use unsupported primitive:" << node->DebugString()
90 << ", location:" << trace::GetDebugInfoStr(node->debug_info());
91 }
92
CheckAbstractScalar(const AnfNodePtr & node)93 bool CheckAbstractScalar(const AnfNodePtr &node) {
94 MS_EXCEPTION_IF_NULL(node);
95 AbstractBasePtr abstract = node->abstract();
96 if (abstract->isa<AbstractScalar>()) {
97 TypePtr type = abstract->GetTypeTrack();
98 MS_EXCEPTION_IF_NULL(type);
99 if (type->isa<EnvType>() || type->isa<MsClassType>()) {
100 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString()
101 << ", location:" << trace::GetDebugInfoStr(node->debug_info());
102 }
103 auto real_node = node;
104 if (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
105 real_node = real_node->cast<CNodePtr>()->input(1);
106 }
107 // Only allow string/number type from external.
108 if (type->isa<External>() && !IsValueNode<StringImm>(real_node) && !IsValueNode<FP32Imm>(real_node) &&
109 !IsValueNode<FP64Imm>(real_node)) {
110 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString()
111 << ", node: " << real_node->DebugString()
112 << "\nPlease check your code:" << trace::GetDebugInfoStr(node->debug_info());
113 }
114 // When a DeadNode is renormalized before, its abstract may be changed to
115 // AbstractScalar(std:: make_shared<Int32Imm>(0), std:: make_shared<Problem>()).
116 if (type->isa<Problem>()) {
117 auto value = abstract->GetValueTrack();
118 MS_EXCEPTION_IF_NULL(value);
119 node->set_abstract(value->ToAbstract());
120 }
121 return true;
122 }
123 return false;
124 }
125
ValidateAbstract(const AnfNodePtr & node)126 void ValidateAbstract(const AnfNodePtr &node) {
127 if (node == nullptr) {
128 MS_LOG(DEBUG) << "Node to validate is invalid";
129 return;
130 }
131 AbstractBasePtr abstract = node->abstract();
132 if (abstract == nullptr) {
133 MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
134 return;
135 }
136 if (CheckAbstractScalar(node)) {
137 return;
138 }
139 if (abstract->isa<AbstractProblem>()) {
140 // NOTICE: validate dead code?
141 MS_LOG(DEBUG) << "AbstractProblem in the graph: " << abstract->ToString();
142 return;
143 }
144 bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
145 abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
146 abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
147 abstract->isa<AbstractRefTensor>() || abstract->isa<AbstractMapTensor>() ||
148 abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>() ||
149 abstract->isa<abstract::AbstractScript>();
150 if (is_legal_abstract) {
151 return;
152 }
153
154 // Other types show exception
155 MS_LOG(INTERNAL_EXCEPTION) << "Illegal type in the graph: " << abstract->ToString()
156 << ", node: " << node->DebugString()
157 << "\nPlease check your code:" << trace::GetDebugInfoStr(node->debug_info());
158 }
159
CheckValueTuple(const AnfNodePtr & node)160 void CheckValueTuple(const AnfNodePtr &node) {
161 MS_EXCEPTION_IF_NULL(node);
162 auto value_node = node->cast_ptr<ValueNode>();
163 MS_EXCEPTION_IF_NULL(value_node);
164 const auto &value = value_node->value();
165 MS_EXCEPTION_IF_NULL(value);
166 auto value_tuple = value->cast_ptr<ValueTuple>();
167 MS_EXCEPTION_IF_NULL(value_tuple);
168 const auto &tuple_values = value_tuple->value();
169 for (const auto &tuple_value : tuple_values) {
170 auto input_node = NewValueNode(tuple_value);
171 ValidateOperation(input_node);
172 }
173 }
174
CheckAssignReturnValue(const AnfNodePtr & node)175 void CheckAssignReturnValue(const AnfNodePtr &node) {
176 static const PrimitiveSet assign_prims = {prim::kPrimAssign, prim::kPrimAssignAdd, prim::kPrimAssignSub};
177 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
178 auto real_input = node->cast_ptr<CNode>()->input(1);
179 while (IsPrimitiveCNode(real_input, prim::kPrimDepend)) {
180 real_input = real_input->cast_ptr<CNode>()->input(1);
181 }
182 if (!IsOneOfPrimitiveCNode(real_input, assign_prims)) {
183 return;
184 }
185 } else if (!IsOneOfPrimitiveCNode(node, assign_prims)) {
186 return;
187 }
188 auto fg = node->func_graph();
189 MS_EXCEPTION_IF_NULL(fg);
190 auto mgr = fg->manager();
191 MS_EXCEPTION_IF_NULL(mgr);
192 auto &node_users = mgr->node_users();
193 auto iter = node_users.find(node);
194 if (iter == node_users.end()) {
195 return;
196 }
197 static const PrimitiveSet virtual_prims = {
198 prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
199 prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimLoad,
200 prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimDynamicLossScale};
201 auto users = iter->second;
202 for (const auto &user : users) {
203 auto user_node = user.first;
204 if (!IsOneOfPrimitiveCNode(user_node, virtual_prims)) {
205 MS_LOG(WARNING) << "Deprecated: the return value of Assign/AssignAdd/AssignSub operator will be removed "
206 << "in subsequent releases.\n"
207 << "You can modify the code from:\na = P.Assign()(param, value)\nb = a * 2\nto: \n"
208 << "P.Assign()(param, value)\nb = param * 2\n"
209 << "Please check your code:" << trace::GetDebugInfoStr(node->debug_info());
210 }
211 }
212 }
213
CheckDeadNodeInOutputRecursively(const AnfNodePtr & node,const AbstractBasePtr & abstract)214 void CheckDeadNodeInOutputRecursively(const AnfNodePtr &node, const AbstractBasePtr &abstract) {
215 if (abstract == nullptr) {
216 return;
217 }
218 TypePtr type = abstract->BuildType();
219 MS_EXCEPTION_IF_NULL(type);
220 if (type->isa<Problem>() || type->isa<Function>()) {
221 MS_LOG(EXCEPTION) << "Function in output is not supported. Please check your code. "
222 << trace::GetDebugInfoStr(node->debug_info());
223 }
224 if (abstract->isa<AbstractSequence>()) {
225 auto abs_seq = abstract->cast_ptr<AbstractSequence>();
226 for (const auto &elem : abs_seq->elements()) {
227 CheckDeadNodeInOutputRecursively(node, elem);
228 }
229 }
230 }
231
ValidateTopGraphOutput(const AnfNodePtr & node)232 void ValidateTopGraphOutput(const AnfNodePtr &node) {
233 MS_EXCEPTION_IF_NULL(node);
234 auto abstract = node->abstract();
235 CheckDeadNodeInOutputRecursively(node, abstract);
236 }
237
Validate(const FuncGraphPtr & func_graph)238 void Validate(const FuncGraphPtr &func_graph) {
239 FuncGraphManagerPtr mgr = Manage(func_graph, false);
240 MS_EXCEPTION_IF_NULL(mgr);
241 ValidateTopGraphOutput(func_graph->output());
242 const AnfNodeSet &all_nodes = mgr->all_nodes();
243 for (auto node : all_nodes) {
244 TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
245 CheckAssignReturnValue(node);
246 while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
247 node = node->cast_ptr<CNode>()->input(1);
248 }
249 if (IsValueNode<ValueTuple>(node)) {
250 CheckValueTuple(node);
251 continue;
252 }
253 ValidateOperation(node);
254 }
255 for (const auto &node : all_nodes) {
256 ValidateAbstract(node);
257 }
258 }
259 } // namespace validator
260 } // namespace mindspore
261