1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H 19 20 #include <string> 21 #include <utility> 22 #include <vector> 23 #include "google/protobuf/message.h" 24 #include "proto/onnx.pb.h" 25 #include "include/errorcode.h" 26 #include "src/common/log_adapter.h" 27 #include "schema/inner/model_generated.h" 28 #include "ir/dtype/type_id.h" 29 #include "ops/primitive_c.h" 30 #include "mindspore/core/utils/check_convert_utils.h" 31 32 namespace mindspore { 33 namespace lite { 34 class OnnxNodeParser { 35 public: OnnxNodeParser(std::string node_name)36 explicit OnnxNodeParser(std::string node_name) : name_(std::move(node_name)) {} 37 38 virtual ~OnnxNodeParser() = default; 39 Parse(const onnx::GraphProto & onnx_graph,const onnx::NodeProto & onnx_node)40 virtual ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) { 41 return nullptr; 42 } 43 set_opset_version(int64_t version)44 static STATUS set_opset_version(int64_t version) { 45 opset_version_ = version; 46 return RET_OK; 47 } 48 opset_version()49 static int64_t opset_version() { return opset_version_; } 50 51 static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, const tensor::TensorPtr &tensor_info); 52 53 static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); 54 55 static size_t GetOnnxElementNum(const onnx::TensorProto &onnx_tensor, bool *overflowed); 56 57 protected: 58 static mindspore::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr); 59 60 static STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type); 61 62 static const void *GetOnnxRawData(const onnx::TensorProto &onnx_const_tensor, TypeId data_type, size_t data_count, 63 size_t *data_size); 64 65 const std::string name_{}; 66 67 private: 68 static int64_t opset_version_; 69 }; 70 } // namespace lite 71 } // namespace mindspore 72 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NODE_PARSER_H 73