1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H 18 #define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H 19 20 #include <string> 21 #include <map> 22 #include <unordered_map> 23 #include "google/protobuf/io/zero_copy_stream_impl.h" 24 #include "ir/func_graph.h" 25 #include "proto/mind_ir.pb.h" 26 27 namespace mindspore { 28 using int32 = int32_t; 29 using int64 = int64_t; 30 using uint64 = uint64_t; 31 class MSANFModelParser { 32 public: MSANFModelParser()33 MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {} 34 ~MSANFModelParser() = default; 35 LoadTensorMapClear()36 static void LoadTensorMapClear() { load_tensor_map_.clear(); } 37 FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto); 38 bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto); 39 GetProducerName()40 std::string GetProducerName() { return producer_name_; } GetProducerVersion()41 std::string GetProducerVersion() { return model_version_; } GetIrVersion()42 std::string GetIrVersion() { return ir_version_; } SetLite()43 void SetLite() { is_lite_ = true; } IsLite()44 bool IsLite() const { return is_lite_; } SetIncLoad()45 void SetIncLoad() { inc_load_ = true; } IsIncLoad()46 bool IsIncLoad() const { return inc_load_; } 47 48 private: 49 bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); 50 bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); 51 bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); 52 bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto); 53 bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto); 54 tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto); 55 CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto); 56 bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); 57 bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); 58 bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); 59 void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, 60 std::unordered_map<std::string, ValuePtr> *multi_value_map); 61 ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index); 62 ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) const; 63 bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto); 64 bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto); 65 AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto); 66 void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr); 67 bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); 68 bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto); 69 bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); 70 bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); 71 bool ObtainValueNodeInNoneForm(const std::string &value_node_name); 72 bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); 73 std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode( 74 const mind_ir::AttributeProto &attr_proto); 75 AnfNodePtr GetAnfNode(const std::string &node_name); 76 77 std::string producer_name_; 78 std::string model_version_; 79 std::string ir_version_; 80 bool is_lite_ = false; 81 bool inc_load_ = false; 82 std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; 83 static std::map<std::string, tensor::TensorPtr> load_tensor_map_; 84 }; 85 } // namespace mindspore 86 87 #endif // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H 88