• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H
19 
20 #include <string>
21 #include <set>
22 #include <map>
23 #include <vector>
24 #include "include/errorcode.h"
25 #include "src/common/log_adapter.h"
26 #include "ir/dtype/type_id.h"
27 
28 namespace mindspore {
29 namespace lite {
30 class ReturnCode {
31  public:
GetSingleReturnCode()32   static ReturnCode *GetSingleReturnCode() {
33     static ReturnCode return_code;
34     return &return_code;
35   }
UpdateReturnCode(STATUS status)36   void UpdateReturnCode(STATUS status) {
37     if (status_code_ == RET_OK) {
38       status_code_ = status;
39     }
40   }
status_code()41   STATUS status_code() const { return status_code_; }
42 
43  private:
44   ReturnCode() = default;
45   virtual ~ReturnCode() = default;
46   int status_code_ = RET_OK;
47 };
48 
49 class NotSupportOp {
50  public:
GetInstance()51   static NotSupportOp *GetInstance() {
52     static NotSupportOp not_support_op;
53     return &not_support_op;
54   }
set_fmk_type(const std::string & fmk_type)55   void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; }
InsertOp(const std::string & op_name)56   void InsertOp(const std::string &op_name) { not_support_ops_.insert(op_name); }
PrintOps()57   void PrintOps() const {
58     if (!not_support_ops_.empty()) {
59       MS_LOG(ERROR) << "===========================================";
60       MS_LOG(ERROR) << "UNSUPPORTED OP LIST:";
61       for (auto &op_name : not_support_ops_) {
62         MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name;
63       }
64       MS_LOG(ERROR) << "===========================================";
65     }
66   }
67 
68  private:
69   NotSupportOp() = default;
70   virtual ~NotSupportOp() = default;
71   std::set<std::string> not_support_ops_;
72   std::string fmk_type_;
73 };
74 
75 class ConverterContext {
76  public:
GetInstance()77   static ConverterContext *GetInstance() {
78     static ConverterContext converter_context;
79     return &converter_context;
80   }
81 
UpdateGraphInputDType(int32_t index,int32_t dtype)82   void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; }
GetGraphInputDType(int32_t index)83   int32_t GetGraphInputDType(int32_t index) const {
84     if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) {
85       return TypeId::kTypeUnknown;
86     }
87     return graph_input_data_type_map_.at(index);
88   }
89 
UpdateGraphOutputDType(int32_t index,int32_t dtype)90   void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; }
GetGraphOutputDType(int32_t index)91   int32_t GetGraphOutputDType(int32_t index) const {
92     if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) {
93       return TypeId::kTypeUnknown;
94     }
95     return graph_output_data_type_map_.at(index);
96   }
97 
UpdateGraphInputTensorShape(const std::string & tensor_name,const std::vector<int64_t> & shape)98   void UpdateGraphInputTensorShape(const std::string &tensor_name, const std::vector<int64_t> &shape) {
99     graph_input_tensor_shape_map_[tensor_name] = shape;
100   }
GetGraphInputTensorShape(const std::string & tensor_name)101   std::vector<int64_t> GetGraphInputTensorShape(const std::string &tensor_name) const {
102     if (graph_input_tensor_shape_map_.find(tensor_name) == graph_input_tensor_shape_map_.end()) {
103       return {};
104     }
105     return graph_input_tensor_shape_map_.at(tensor_name);
106   }
GetGraphInputTensorShapeMapSize()107   size_t GetGraphInputTensorShapeMapSize() { return graph_input_tensor_shape_map_.size(); }
108 
SetGraphOutputTensorNames(const std::vector<std::string> & output_names)109   void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
110     graph_output_tensor_names_ = output_names;
111   }
112 
GetGraphOutputTensorNames()113   const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; }
114 
115  private:
ConverterContext()116   ConverterContext() {}
117   virtual ~ConverterContext() = default;
118   std::map<int32_t, int32_t> graph_input_data_type_map_;
119   std::map<int32_t, int32_t> graph_output_data_type_map_;
120   std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
121   std::vector<std::string> graph_output_tensor_names_;
122 };
123 }  // namespace lite
124 }  // namespace mindspore
125 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H
126