• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/validator.h"
20 
21 #include <memory>
22 #include <mutex>
23 
24 #include "ir/manager.h"
25 #include "ir/dtype.h"
26 #include "pipeline/jit/static_analysis/prim.h"
27 
28 namespace mindspore {
29 namespace validator {
30 using mindspore::abstract::AbstractBase;
31 using mindspore::abstract::AbstractClass;
32 using mindspore::abstract::AbstractError;
33 using mindspore::abstract::AbstractFunction;
34 using mindspore::abstract::AbstractJTagged;
35 using mindspore::abstract::AbstractList;
36 using mindspore::abstract::AbstractRef;
37 using mindspore::abstract::AbstractRowTensor;
38 using mindspore::abstract::AbstractScalar;
39 using mindspore::abstract::AbstractSparseTensor;
40 using mindspore::abstract::AbstractTensor;
41 using mindspore::abstract::AbstractTuple;
42 using mindspore::abstract::AbstractType;
43 
ValidateOperation(const AnfNodePtr & node)44 void ValidateOperation(const AnfNodePtr &node) {
45   if (!IsValueNode<Primitive>(node)) {
46     return;
47   }
48 
49   // Primitive must in whitelist
50   auto prim = GetValueNode<PrimitivePtr>(node);
51   MS_EXCEPTION_IF_NULL(prim);
52   if (abstract::IsInWhiteList(prim)) {
53     return;
54   }
55   if (prim->HasAttr("is_load")) {
56     return;
57   }
58   if (prim->HasPyEvaluator()) {
59     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
60     return;
61   }
62   if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
63     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
64     return;
65   }
66   if (prim->name() == "fake_bprop") {
67     MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
68   }
69 
70   MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name();
71 }
72 
CheckAbstractScalar(const AnfNodePtr & node)73 bool CheckAbstractScalar(const AnfNodePtr &node) {
74   MS_EXCEPTION_IF_NULL(node);
75   AbstractBasePtr ptrBase = node->abstract();
76   if (ptrBase->isa<AbstractScalar>()) {
77     TypePtr ptrType = ptrBase->GetTypeTrack();
78     MS_EXCEPTION_IF_NULL(ptrType);
79     if (ptrType->isa<EnvType>()) {
80       MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString();
81     }
82     if (ptrType->isa<Problem>() || ptrType->isa<External>()) {
83       // only send string in external
84       if (!IsValueNode<StringImm>(node)) {
85         // Validate a type.
86         MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
87                           << " for node=" << node->DebugString();
88       }
89     }
90     return true;
91   }
92   return false;
93 }
94 
ValidateAbstract(const AnfNodePtr & node)95 void ValidateAbstract(const AnfNodePtr &node) {
96   if (node == nullptr) {
97     MS_LOG(DEBUG) << "Node to validate is invalid";
98     return;
99   }
100   AbstractBasePtr ptrBase = node->abstract();
101   if (ptrBase == nullptr) {
102     MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
103     return;
104   }
105   if (ptrBase->isa<AbstractClass>() || ptrBase->isa<AbstractJTagged>()) {
106     // Validate a type.
107     MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString();
108   }
109   if (CheckAbstractScalar(node)) {
110     return;
111   }
112   if (ptrBase->isa<AbstractError>()) {
113     // NOTICE: validate dead code?
114     MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString();
115     return;
116   }
117   bool checkAbstractIslegal =
118     ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
119     ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
120     ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>() ||
121     ptrBase->isa<abstract::AbstractNone>() || ptrBase->isa<abstract::AbstractMonad>();
122   if (checkAbstractIslegal) {
123     return;
124   }
125 
126   // Other types show exception
127   MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
128 }
129 
Validate(const FuncGraphPtr & fg)130 void Validate(const FuncGraphPtr &fg) {
131   FuncGraphManagerPtr mgr = Manage(fg, false);
132   MS_EXCEPTION_IF_NULL(mgr);
133   AnfNodeSet &all_nodes = mgr->all_nodes();
134   for (const auto &anf_node : all_nodes) {
135     ValidateOperation(anf_node);
136     ValidateAbstract(anf_node);
137   }
138 }
139 }  // namespace validator
140 }  // namespace mindspore
141