1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_ 19 20 #include <string> 21 #include <utility> 22 #include <vector> 23 #include <map> 24 #include "google/protobuf/message.h" 25 #include "proto/onnx.pb.h" 26 #include "include/errorcode.h" 27 #include "src/common/log_adapter.h" 28 #include "schema/inner/model_generated.h" 29 #include "ir/dtype/type_id.h" 30 #include "ops/primitive_c.h" 31 #include "mindspore/core/utils/check_convert_utils.h" 32 #include "tools/converter/parser/parser_utils.h" 33 #include "ops/op_utils.h" 34 35 namespace mindspore { 36 namespace lite { 37 class ExternalDataInfo { 38 public: GetRelativePath()39 std::string GetRelativePath() const { return relative_path_; } GetOffset()40 size_t GetOffset() const { return static_cast<size_t>(offset_); } GetLength()41 size_t GetLength() const { return length_; } 42 static STATUS Create(const google::protobuf::RepeatedPtrField<onnx::StringStringEntryProto> &external_data, 43 ExternalDataInfo *external_data_info); 44 45 private: 46 static bool StringMapKeyIs(const std::string &key, const onnx::StringStringEntryProto &string_map); 47 std::string relative_path_; 48 off_t offset_ = 0; 49 size_t length_ = 0; 50 std::string checksum_; 51 }; 52 53 class OnnxNodeParser { 54 public: OnnxNodeParser(std::string node_name)55 explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {} 56 57 virtual ~OnnxNodeParser() = default; 58 Parse(const onnx::GraphProto & onnx_graph,const onnx::NodeProto & onnx_node)59 virtual PrimitiveCPtr Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { return nullptr; } 60 set_opset_version(int64_t version)61 static STATUS set_opset_version(int64_t version) { 62 opset_version_ = version; 63 return RET_OK; 64 } 65 opset_version()66 static int64_t opset_version() { return opset_version_; } 67 68 static tensor::TensorPtr CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor); 69 70 static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); 71 72 static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed); 73 74 static STATUS LoadOnnxExternalTensorData(const onnx::TensorProto &onnx_const_tensor, 75 const tensor::TensorPtr &tensor_info, const std::string &model_file, 76 std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas); 77 78 static const onnx::TensorProto *GetConstantTensorData(const onnx::GraphProto &onnx_graph, 79 const std::string &input_name); 80 SetOnnxModelFile(const std::string model_file)81 static void SetOnnxModelFile(const std::string model_file) { model_file_ = model_file; } 82 GetOnnxModelFile()83 static std::string GetOnnxModelFile() { return model_file_; } 84 85 static void SetTypeAndValueForFloat(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, 86 size_t data_count); 87 88 static void SetTypeAndValueForBool(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, 89 size_t data_count); 90 91 static STATUS SetDataTypeAndValue(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, size_t data_count, 92 int *type); 93 94 protected: 95 static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); 96 97 static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); 98 99 static int GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t data_count, 100 const tensor::TensorPtr &tensor_info); 101 static int GetOnnxListData(const onnx::TensorProto &onnx_const_tensor, size_t data_count, 102 const tensor::TensorPtr &tensor_info); 103 104 static STATUS SetExternalTensorFile(const std::string &model_file, std::string *external_tensor_dir); 105 106 static const void *LoadOnnxRawData(const onnx::TensorProto &onnx_const_tensor, size_t *data_size, 107 const std::string &model_file, 108 std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas); 109 110 const std::string name_{}; 111 112 private: 113 static int64_t opset_version_; 114 static inline std::string model_file_ = ""; 115 }; 116 } // namespace lite 117 } // namespace mindspore 118 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ONNX_NODE_PARSER_H_ 119