• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/validator.h"
20 
21 #include <memory>
22 #include <mutex>
23 #include <string>
24 
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/other_ops.h"
28 #include "mindspore/core/ops/nn_optimizer_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "ir/manager.h"
31 #include "ir/dtype.h"
32 #include "pipeline/jit/ps/static_analysis/prim.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/debug/trace.h"
35 
36 namespace mindspore {
37 namespace validator {
38 using mindspore::abstract::AbstractBase;
39 using mindspore::abstract::AbstractFunction;
40 using mindspore::abstract::AbstractJTagged;
41 using mindspore::abstract::AbstractList;
42 using mindspore::abstract::AbstractMapTensor;
43 using mindspore::abstract::AbstractProblem;
44 using mindspore::abstract::AbstractRefTensor;
45 using mindspore::abstract::AbstractRowTensor;
46 using mindspore::abstract::AbstractScalar;
47 using mindspore::abstract::AbstractSequence;
48 using mindspore::abstract::AbstractTensor;
49 using mindspore::abstract::AbstractTuple;
50 using mindspore::abstract::AbstractType;
51 
ValidateOperation(const AnfNodePtr & node)52 void ValidateOperation(const AnfNodePtr &node) {
53   if (!IsValueNode<Primitive>(node)) {
54     return;
55   }
56 
57   // Primitive must in whitelist
58   auto prim = GetValueNode<PrimitivePtr>(node);
59   MS_EXCEPTION_IF_NULL(prim);
60   if (prim->isa<prim::DoSignaturePrimitive>()) {
61     MS_LOG(INTERNAL_EXCEPTION) << "Illegal DoSignaturePrimitive '" << prim->name() << "' in the graph."
62                                << "node:" << node->DebugString()
63                                << ", location:" << trace::GetDebugInfoStr(node->debug_info());
64   }
65   if (abstract::IsInWhiteList(prim)) {
66     return;
67   }
68   if (prim->HasAttr("is_load")) {
69     return;
70   }
71   if (prim->name() == "PyExecute") {
72     return;
73   }
74   if (prim->name() == "TensorMove") {
75     return;
76   }
77 
78   if (prim->isa<PrimitivePy>()) {
79     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
80     return;
81   }
82   if (prim->name() == "fake_bprop") {
83     MS_LOG(INTERNAL_EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"))
84                                << "node:" << node->DebugString()
85                                << ", location:" << trace::GetDebugInfoStr(node->debug_info());
86   }
87 
88   MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name()
89                     << ". Please check whether to use unsupported primitive:" << node->DebugString()
90                     << ", location:" << trace::GetDebugInfoStr(node->debug_info());
91 }
92 
CheckAbstractScalar(const AnfNodePtr & node)93 bool CheckAbstractScalar(const AnfNodePtr &node) {
94   MS_EXCEPTION_IF_NULL(node);
95   AbstractBasePtr abstract = node->abstract();
96   if (abstract->isa<AbstractScalar>()) {
97     TypePtr type = abstract->GetTypeTrack();
98     MS_EXCEPTION_IF_NULL(type);
99     if (type->isa<EnvType>() || type->isa<MsClassType>()) {
100       MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString()
101                         << ", location:" << trace::GetDebugInfoStr(node->debug_info());
102     }
103     auto real_node = node;
104     if (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
105       real_node = real_node->cast<CNodePtr>()->input(1);
106     }
107     // Only allow string/number type from external.
108     if (type->isa<External>() && !IsValueNode<StringImm>(real_node) && !IsValueNode<FP32Imm>(real_node) &&
109         !IsValueNode<FP64Imm>(real_node)) {
110       MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString()
111                         << ", node: " << real_node->DebugString()
112                         << "\nPlease check your code:" << trace::GetDebugInfoStr(node->debug_info());
113     }
114     // When a DeadNode is renormalized before, its abstract may be changed to
115     // AbstractScalar(std:: make_shared<Int32Imm>(0), std:: make_shared<Problem>()).
116     if (type->isa<Problem>()) {
117       auto value = abstract->GetValueTrack();
118       MS_EXCEPTION_IF_NULL(value);
119       node->set_abstract(value->ToAbstract());
120     }
121     return true;
122   }
123   return false;
124 }
125 
ValidateAbstract(const AnfNodePtr & node)126 void ValidateAbstract(const AnfNodePtr &node) {
127   if (node == nullptr) {
128     MS_LOG(DEBUG) << "Node to validate is invalid";
129     return;
130   }
131   AbstractBasePtr abstract = node->abstract();
132   if (abstract == nullptr) {
133     MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
134     return;
135   }
136   if (CheckAbstractScalar(node)) {
137     return;
138   }
139   if (abstract->isa<AbstractProblem>()) {
140     // NOTICE: validate dead code?
141     MS_LOG(DEBUG) << "AbstractProblem in the graph: " << abstract->ToString();
142     return;
143   }
144   bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
145                            abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
146                            abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
147                            abstract->isa<AbstractRefTensor>() || abstract->isa<AbstractMapTensor>() ||
148                            abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>() ||
149                            abstract->isa<abstract::AbstractScript>();
150   if (is_legal_abstract) {
151     return;
152   }
153 
154   // Other types show exception
155   MS_LOG(INTERNAL_EXCEPTION) << "Illegal type in the graph: " << abstract->ToString()
156                              << ", node: " << node->DebugString()
157                              << "\nPlease check your code:" << trace::GetDebugInfoStr(node->debug_info());
158 }
159 
CheckValueTuple(const AnfNodePtr & node)160 void CheckValueTuple(const AnfNodePtr &node) {
161   MS_EXCEPTION_IF_NULL(node);
162   auto value_node = node->cast_ptr<ValueNode>();
163   MS_EXCEPTION_IF_NULL(value_node);
164   const auto &value = value_node->value();
165   MS_EXCEPTION_IF_NULL(value);
166   auto value_tuple = value->cast_ptr<ValueTuple>();
167   MS_EXCEPTION_IF_NULL(value_tuple);
168   const auto &tuple_values = value_tuple->value();
169   for (const auto &tuple_value : tuple_values) {
170     auto input_node = NewValueNode(tuple_value);
171     ValidateOperation(input_node);
172   }
173 }
174 
CheckAssignReturnValue(const AnfNodePtr & node)175 void CheckAssignReturnValue(const AnfNodePtr &node) {
176   static const PrimitiveSet assign_prims = {prim::kPrimAssign, prim::kPrimAssignAdd, prim::kPrimAssignSub};
177   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
178     auto real_input = node->cast_ptr<CNode>()->input(1);
179     while (IsPrimitiveCNode(real_input, prim::kPrimDepend)) {
180       real_input = real_input->cast_ptr<CNode>()->input(1);
181     }
182     if (!IsOneOfPrimitiveCNode(real_input, assign_prims)) {
183       return;
184     }
185   } else if (!IsOneOfPrimitiveCNode(node, assign_prims)) {
186     return;
187   }
188   auto fg = node->func_graph();
189   MS_EXCEPTION_IF_NULL(fg);
190   auto mgr = fg->manager();
191   MS_EXCEPTION_IF_NULL(mgr);
192   auto &node_users = mgr->node_users();
193   auto iter = node_users.find(node);
194   if (iter == node_users.end()) {
195     return;
196   }
197   static const PrimitiveSet virtual_prims = {
198     prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
199     prim::kPrimMakeTuple,    prim::kPrimStateSetItem,  prim::kPrimTupleGetItem,  prim::kPrimLoad,
200     prim::kPrimPartial,      prim::kPrimDepend,        prim::kPrimUpdateState,   prim::kPrimDynamicLossScale};
201   auto users = iter->second;
202   for (const auto &user : users) {
203     auto user_node = user.first;
204     if (!IsOneOfPrimitiveCNode(user_node, virtual_prims)) {
205       MS_LOG(WARNING) << "Deprecated: the return value of Assign/AssignAdd/AssignSub operator will be removed "
206                       << "in subsequent releases.\n"
207                       << "You can modify the code from:\na = P.Assign()(param, value)\nb = a * 2\nto: \n"
208                       << "P.Assign()(param, value)\nb = param * 2\n"
209                       << "Please check your code:" << trace::GetDebugInfoStr(node->debug_info());
210     }
211   }
212 }
213 
CheckDeadNodeInOutputRecursively(const AnfNodePtr & node,const AbstractBasePtr & abstract)214 void CheckDeadNodeInOutputRecursively(const AnfNodePtr &node, const AbstractBasePtr &abstract) {
215   if (abstract == nullptr) {
216     return;
217   }
218   TypePtr type = abstract->BuildType();
219   MS_EXCEPTION_IF_NULL(type);
220   if (type->isa<Problem>() || type->isa<Function>()) {
221     MS_LOG(EXCEPTION) << "Function in output is not supported. Please check your code. "
222                       << trace::GetDebugInfoStr(node->debug_info());
223   }
224   if (abstract->isa<AbstractSequence>()) {
225     auto abs_seq = abstract->cast_ptr<AbstractSequence>();
226     for (const auto &elem : abs_seq->elements()) {
227       CheckDeadNodeInOutputRecursively(node, elem);
228     }
229   }
230 }
231 
ValidateTopGraphOutput(const AnfNodePtr & node)232 void ValidateTopGraphOutput(const AnfNodePtr &node) {
233   MS_EXCEPTION_IF_NULL(node);
234   auto abstract = node->abstract();
235   CheckDeadNodeInOutputRecursively(node, abstract);
236 }
237 
Validate(const FuncGraphPtr & func_graph)238 void Validate(const FuncGraphPtr &func_graph) {
239   FuncGraphManagerPtr mgr = Manage(func_graph, false);
240   MS_EXCEPTION_IF_NULL(mgr);
241   ValidateTopGraphOutput(func_graph->output());
242   const AnfNodeSet &all_nodes = mgr->all_nodes();
243   for (auto node : all_nodes) {
244     TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
245     CheckAssignReturnValue(node);
246     while (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
247       node = node->cast_ptr<CNode>()->input(1);
248     }
249     if (IsValueNode<ValueTuple>(node)) {
250       CheckValueTuple(node);
251       continue;
252     }
253     ValidateOperation(node);
254   }
255   for (const auto &node : all_nodes) {
256     ValidateAbstract(node);
257   }
258 }
259 }  // namespace validator
260 }  // namespace mindspore
261