• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
18 #define MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
19 
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/func_graph.h"
#include "proto/mind_ir.pb.h"
26 
27 namespace mindspore {
28 using int32 = int32_t;
29 using int64 = int64_t;
30 using uint64 = uint64_t;
31 class MSANFModelParser {
32  public:
MSANFModelParser()33   MSANFModelParser() : producer_name_(""), model_version_(""), ir_version_("") {}
34   ~MSANFModelParser() = default;
35 
LoadTensorMapClear()36   static void LoadTensorMapClear() { load_tensor_map_.clear(); }
37   FuncGraphPtr Parse(const mind_ir::ModelProto &model_proto);
38   bool MSANFParseModelConfigureInfo(const mind_ir::ModelProto &model_proto);
39 
GetProducerName()40   std::string GetProducerName() { return producer_name_; }
GetProducerVersion()41   std::string GetProducerVersion() { return model_version_; }
GetIrVersion()42   std::string GetIrVersion() { return ir_version_; }
SetLite()43   void SetLite() { is_lite_ = true; }
IsLite()44   bool IsLite() const { return is_lite_; }
SetIncLoad()45   void SetIncLoad() { inc_load_ = true; }
IsIncLoad()46   bool IsIncLoad() const { return inc_load_; }
47 
48  private:
49   bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
50   bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
51   bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
52   bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &tensor_proto);
53   bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
54   tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
55   CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
56   bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
57   bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
58   bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
59   void ObtainCNodeAttrInScalarForm(const mind_ir::AttributeProto &attr_proto,
60                                    std::unordered_map<std::string, ValuePtr> *multi_value_map);
61   ValuePtr ParseAttrInScalarForm(const mind_ir::AttributeProto &attr_proto, int index);
62   ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto) const;
63   bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
64   bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
65   AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
66   void SetCNodeAbastract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
67   bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
68   bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
69   bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
70   bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
71   bool ObtainValueNodeInNoneForm(const std::string &value_node_name);
72   bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
73   std::unordered_map<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
74     const mind_ir::AttributeProto &attr_proto);
75   AnfNodePtr GetAnfNode(const std::string &node_name);
76 
77   std::string producer_name_;
78   std::string model_version_;
79   std::string ir_version_;
80   bool is_lite_ = false;
81   bool inc_load_ = false;
82   std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
83   static std::map<std::string, tensor::TensorPtr> load_tensor_map_;
84 };
85 }  // namespace mindspore
86 
87 #endif  // MINDSPORE_CORE_LOAD_MINDIR_ANF_MODEL_PARSER_H
88