• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_
19 
20 #include <string>
21 #include <utility>
22 #include <vector>
23 #include <map>
24 #include "google/protobuf/message.h"
25 #include "proto/onnx.pb.h"
26 #include "include/errorcode.h"
27 #include "src/common/log_adapter.h"
28 #include "schema/inner/model_generated.h"
29 #include "ir/dtype/type_id.h"
30 #include "ops/primitive_c.h"
31 #include "mindspore/core/utils/check_convert_utils.h"
32 #include "tools/converter/parser/parser_utils.h"
33 #include "ops/op_utils.h"
34 
35 namespace mindspore {
36 namespace lite {
37 class ExternalDataInfo {
38  public:
GetRelativePath()39   std::string GetRelativePath() const { return relative_path_; }
GetOffset()40   size_t GetOffset() const { return static_cast<size_t>(offset_); }
GetLength()41   size_t GetLength() const { return length_; }
42   static STATUS Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &external_data,
43                        ExternalDataInfo *external_data_info);
44 
45  private:
46   static bool StringMapKeyIs(const std::string &key, const onnx::StringStringEntryProto &string_map);
47   std::string relative_path_;
48   off_t offset_ = 0;
49   size_t length_ = 0;
50   std::string checksum_;
51 };
52 
53 class OnnxNodeParser {
54  public:
OnnxNodeParser(std::string node_name)55   explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {}
56 
57   virtual ~OnnxNodeParser() = default;
58 
Parse(const onnx::GraphProto & onnx_graph,const onnx::NodeProto & onnx_node)59   virtual PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { return nullptr; }
60 
set_opset_version(int64_t version)61   static STATUS set_opset_version(int64_t version) {
62     opset_version_ = version;
63     return RET_OK;
64   }
65 
opset_version()66   static int64_t opset_version() { return opset_version_; }
67 
68   static tensor::TensorPtr CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor);
69 
70   static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
71 
72   static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed);
73 
74   static STATUS LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor,
75                                            const tensor::TensorPtr &tensor_info, const std::string &model_file,
76                                            std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas);
77 
78   static const onnx::TensorProto *GetConstantTensorData(const onnx::GraphProto &onnx_graph,
79                                                         const std::string &input_name);
80 
SetOnnxModelFile(const std::string model_file)81   static void SetOnnxModelFile(const std::string model_file) { model_file_ = model_file; }
82 
GetOnnxModelFile()83   static std::string GetOnnxModelFile() { return model_file_; }
84 
85   static void SetTypeAndValueForFloat(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
86                                       size_t data_count);
87 
88   static void SetTypeAndValueForBool(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
89                                      size_t data_count);
90 
91   static STATUS SetDataTypeAndValue(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, size_t data_count,
92                                     int *type);
93 
94  protected:
95   static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
96 
97   static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);
98 
99   static int GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
100                             const tensor::TensorPtr &tensor_info);
101   static int GetOnnxListData(const onnx::TensorProto &onnx_const_tensor, size_t data_count,
102                              const tensor::TensorPtr &tensor_info);
103 
104   static STATUS SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir);
105 
106   static const void *LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size,
107                                      const std::string &model_file,
108                                      std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas);
109 
110   const std::string name_{};
111 
112  private:
113   static int64_t opset_version_;
114   static inline std::string model_file_ = "";
115 };
116 }  // namespace lite
117 }  // namespace mindspore
118 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_
119