/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_

#include <string>
#include <set>
#include <map>
#include <vector>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "ir/dtype/type_id.h"
#include "include/registry/converter_context.h"

namespace mindspore {
namespace lite {
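// Singleton that records the first non-OK status code reported during conversion.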
class ReturnCode {
 public:
  static ReturnCode *GetSingleReturnCode() {
    static ReturnCode return_code;
    return &return_code;
  }
  void UpdateReturnCode(STATUS status) {
    if (status_code_ == RET_OK) {
      status_code_ = status;
    }
  }
  STATUS status_code() const { return status_code_; }

 private:
  ReturnCode() = default;
  virtual ~ReturnCode() = default;
  int status_code_ = RET_OK;
};

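// Singleton that collects the operators the converter does not support and prints them as an error summary.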
class NotSupportOp {
 public:
  static NotSupportOp *GetInstance() {
    static NotSupportOp not_support_op;
    return &not_support_op;
  }
  void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; }
  void InsertOp(const std::string &op_name) { (void)not_support_ops_.insert(op_name); }
  void PrintOps() const {
    if (!not_support_ops_.empty()) {
      MS_LOG(ERROR) << "===========================================";
      MS_LOG(ERROR) << "UNSUPPORTED OP LIST:";
      for (auto &op_name : not_support_ops_) {
        MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name;
      }
      MS_LOG(ERROR) << "===========================================";
    }
  }

 private:
  NotSupportOp() = default;
  virtual ~NotSupportOp() = default;
  std::set<std::string> not_support_ops_;
  std::string fmk_type_;
};

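// Singleton holding converter-wide state: graph input/output data types and shapes, graph output tensor names,
// externally supplied config sections, and the target device.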
class ConverterInnerContext {
 public:
  static ConverterInnerContext *GetInstance() {
    static ConverterInnerContext converter_context;
    return &converter_context;
  }

  void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; }
  int32_t GetGraphInputDType(int32_t index) const {
    if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) {
      return TypeId::kTypeUnknown;
    }
    return graph_input_data_type_map_.at(index);
  }

  void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; }
  int32_t GetGraphOutputDType(int32_t index) const {
    if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) {
      return TypeId::kTypeUnknown;
    }
    return graph_output_data_type_map_.at(index);
  }

  void UpdateGraphInputTensorShape(const std::string &tensor_name, const std::vector<int64_t> &shape) {
    graph_input_tensor_shape_map_[tensor_name] = shape;
    MS_LOG(INFO) << "Update shape of input " << tensor_name << " to " << shape;
  }
  std::vector<int64_t> GetGraphInputTensorShape(const std::string &tensor_name) const {
    if (graph_input_tensor_shape_map_.find(tensor_name) == graph_input_tensor_shape_map_.end()) {
      return {};
    }
    return graph_input_tensor_shape_map_.at(tensor_name);
  }
  size_t GetGraphInputTensorShapeMapSize() const { return graph_input_tensor_shape_map_.size(); }

  void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
    graph_output_tensor_names_ = output_names;
  }

  const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; }

  void SetExternalUsedConfigInfos(const std::string &section,
                                  const std::map<std::string, std::string> &external_infos) {
    for (auto const &external_info : external_infos) {
      if (external_used_config_infos_[section].find(external_info.first) !=
          external_used_config_infos_[section].end()) {
        MS_LOG(WARNING) << "This content " << external_info.first
                        << " has already been saved. Its value will now be overwritten.";
      }
      external_used_config_infos_[section][external_info.first] = external_info.second;
    }
  }

  const std::map<std::string, std::map<std::string, std::string>> &GetExternalUsedConfigInfos() const {
    return external_used_config_infos_;
  }

  void SetTargetDevice(const std::string &target_device) { target_device_ = target_device; }
  std::string GetTargetDevice() const { return target_device_; }

  void Free() {
    graph_input_data_type_map_.clear();
    graph_output_data_type_map_.clear();
    graph_input_tensor_shape_map_.clear();
    graph_output_tensor_names_.clear();
    external_used_config_infos_.clear();
    target_device_ = "";
  }

 private:
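  // Pre-register the config sections that can be filled in externally: common quant, full quant,
  // data pre-process, and mix-bit weight quant parameters.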
  ConverterInnerContext() {
    (void)external_used_config_infos_.emplace(mindspore::converter::KCommonQuantParam,
                                              std::map<std::string, std::string>{});
    (void)external_used_config_infos_.emplace(mindspore::converter::KFullQuantParam,
                                              std::map<std::string, std::string>{});
    (void)external_used_config_infos_.emplace(mindspore::converter::KDataPreProcess,
                                              std::map<std::string, std::string>{});
    (void)external_used_config_infos_.emplace(mindspore::converter::KMixBitWeightQuantParam,
                                              std::map<std::string, std::string>{});
  }
  virtual ~ConverterInnerContext() = default;
  std::map<int32_t, int32_t> graph_input_data_type_map_;
  std::map<int32_t, int32_t> graph_output_data_type_map_;
  std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
  std::vector<std::string> graph_output_tensor_names_;
  std::map<std::string, std::map<std::string, std::string>> external_used_config_infos_;
  std::string target_device_;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_