1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #include "pipeline/jit/validator.h" 20 21 #include <memory> 22 #include <mutex> 23 24 #include "ir/manager.h" 25 #include "ir/dtype.h" 26 #include "pipeline/jit/static_analysis/prim.h" 27 28 namespace mindspore { 29 namespace validator { 30 using mindspore::abstract::AbstractBase; 31 using mindspore::abstract::AbstractClass; 32 using mindspore::abstract::AbstractError; 33 using mindspore::abstract::AbstractFunction; 34 using mindspore::abstract::AbstractJTagged; 35 using mindspore::abstract::AbstractList; 36 using mindspore::abstract::AbstractRef; 37 using mindspore::abstract::AbstractRowTensor; 38 using mindspore::abstract::AbstractScalar; 39 using mindspore::abstract::AbstractSparseTensor; 40 using mindspore::abstract::AbstractTensor; 41 using mindspore::abstract::AbstractTuple; 42 using mindspore::abstract::AbstractType; 43 ValidateOperation(const AnfNodePtr & node)44void ValidateOperation(const AnfNodePtr &node) { 45 if (!IsValueNode<Primitive>(node)) { 46 return; 47 } 48 49 // Primitive must in whitelist 50 auto prim = GetValueNode<PrimitivePtr>(node); 51 MS_EXCEPTION_IF_NULL(prim); 52 if (abstract::IsInWhiteList(prim)) { 53 return; 54 } 55 if (prim->HasAttr("is_load")) { 56 return; 57 } 58 if (prim->HasPyEvaluator()) { 59 MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; 60 return; 61 } 62 if (prim->prim_type() == PrimType::kPrimTypePyCheck) { 63 MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method."; 64 return; 65 } 66 if (prim->name() == "fake_bprop") { 67 MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info")); 68 } 69 70 MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); 71 } 72 CheckAbstractScalar(const AnfNodePtr & node)73bool CheckAbstractScalar(const AnfNodePtr &node) { 74 MS_EXCEPTION_IF_NULL(node); 75 AbstractBasePtr ptrBase = node->abstract(); 76 if (ptrBase->isa<AbstractScalar>()) { 77 TypePtr ptrType = ptrBase->GetTypeTrack(); 78 MS_EXCEPTION_IF_NULL(ptrType); 79 if (ptrType->isa<EnvType>()) { 80 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString(); 81 } 82 if (ptrType->isa<Problem>() || ptrType->isa<External>()) { 83 // only send string in external 84 if (!IsValueNode<StringImm>(node)) { 85 // Validate a type. 86 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() 87 << " for node=" << node->DebugString(); 88 } 89 } 90 return true; 91 } 92 return false; 93 } 94 ValidateAbstract(const AnfNodePtr & node)95void ValidateAbstract(const AnfNodePtr &node) { 96 if (node == nullptr) { 97 MS_LOG(DEBUG) << "Node to validate is invalid"; 98 return; 99 } 100 AbstractBasePtr ptrBase = node->abstract(); 101 if (ptrBase == nullptr) { 102 MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); 103 return; 104 } 105 if (ptrBase->isa<AbstractClass>() || ptrBase->isa<AbstractJTagged>()) { 106 // Validate a type. 107 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString(); 108 } 109 if (CheckAbstractScalar(node)) { 110 return; 111 } 112 if (ptrBase->isa<AbstractError>()) { 113 // NOTICE: validate dead code? 114 MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); 115 return; 116 } 117 bool checkAbstractIslegal = 118 ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || 119 ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() || 120 ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>() || 121 ptrBase->isa<abstract::AbstractNone>() || ptrBase->isa<abstract::AbstractMonad>(); 122 if (checkAbstractIslegal) { 123 return; 124 } 125 126 // Other types show exception 127 MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); 128 } 129 Validate(const FuncGraphPtr & fg)130void Validate(const FuncGraphPtr &fg) { 131 FuncGraphManagerPtr mgr = Manage(fg, false); 132 MS_EXCEPTION_IF_NULL(mgr); 133 AnfNodeSet &all_nodes = mgr->all_nodes(); 134 for (const auto &anf_node : all_nodes) { 135 ValidateOperation(anf_node); 136 ValidateAbstract(anf_node); 137 } 138 } 139 } // namespace validator 140 } // namespace mindspore 141