1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_ 19 20 #include <string> 21 #include <set> 22 #include <map> 23 #include <vector> 24 #include "include/errorcode.h" 25 #include "src/common/log_adapter.h" 26 #include "ir/dtype/type_id.h" 27 #include "include/registry/converter_context.h" 28 29 namespace mindspore { 30 namespace lite { 31 class ReturnCode { 32 public: GetSingleReturnCode()33 static ReturnCode *GetSingleReturnCode() { 34 static ReturnCode return_code; 35 return &return_code; 36 } UpdateReturnCode(STATUS status)37 void UpdateReturnCode(STATUS status) { 38 if (status_code_ == RET_OK) { 39 status_code_ = status; 40 } 41 } status_code()42 STATUS status_code() const { return status_code_; } 43 44 private: 45 ReturnCode() = default; 46 virtual ~ReturnCode() = default; 47 int status_code_ = RET_OK; 48 }; 49 50 class NotSupportOp { 51 public: GetInstance()52 static NotSupportOp *GetInstance() { 53 static NotSupportOp not_support_op; 54 return ¬_support_op; 55 } set_fmk_type(const std::string & fmk_type)56 void set_fmk_type(const std::string &fmk_type) { fmk_type_ = fmk_type; } InsertOp(const std::string & op_name)57 void InsertOp(const std::string &op_name) { (void)not_support_ops_.insert(op_name); } PrintOps()58 void PrintOps() const { 59 if (!not_support_ops_.empty()) { 60 MS_LOG(ERROR) << "==========================================="; 61 MS_LOG(ERROR) << "UNSUPPORTED OP LIST:"; 62 for (auto &op_name : not_support_ops_) { 63 MS_LOG(ERROR) << "FMKTYPE: " << fmk_type_ << ", OP TYPE: " << op_name; 64 } 65 MS_LOG(ERROR) << "==========================================="; 66 } 67 } 68 69 private: 70 NotSupportOp() = default; 71 virtual ~NotSupportOp() = default; 72 std::set<std::string> not_support_ops_; 73 std::string fmk_type_; 74 }; 75 76 class ConverterInnerContext { 77 public: GetInstance()78 static ConverterInnerContext *GetInstance() { 79 static ConverterInnerContext converter_context; 80 return &converter_context; 81 } 82 UpdateGraphInputDType(int32_t index,int32_t dtype)83 void UpdateGraphInputDType(int32_t index, int32_t dtype) { graph_input_data_type_map_[index] = dtype; } GetGraphInputDType(int32_t index)84 int32_t GetGraphInputDType(int32_t index) const { 85 if (graph_input_data_type_map_.find(index) == graph_input_data_type_map_.end()) { 86 return TypeId::kTypeUnknown; 87 } 88 return graph_input_data_type_map_.at(index); 89 } 90 UpdateGraphOutputDType(int32_t index,int32_t dtype)91 void UpdateGraphOutputDType(int32_t index, int32_t dtype) { graph_output_data_type_map_[index] = dtype; } GetGraphOutputDType(int32_t index)92 int32_t GetGraphOutputDType(int32_t index) const { 93 if (graph_output_data_type_map_.find(index) == graph_output_data_type_map_.end()) { 94 return TypeId::kTypeUnknown; 95 } 96 return graph_output_data_type_map_.at(index); 97 } 98 UpdateGraphInputTensorShape(const std::string & tensor_name,const std::vector<int64_t> & shape)99 void UpdateGraphInputTensorShape(const std::string &tensor_name, const std::vector<int64_t> &shape) { 100 graph_input_tensor_shape_map_[tensor_name] = shape; 101 MS_LOG(INFO) << "Update shape of input " << tensor_name << " to " << shape; 102 } GetGraphInputTensorShape(const std::string & tensor_name)103 std::vector<int64_t> GetGraphInputTensorShape(const std::string &tensor_name) const { 104 if (graph_input_tensor_shape_map_.find(tensor_name) == graph_input_tensor_shape_map_.end()) { 105 return {}; 106 } 107 return graph_input_tensor_shape_map_.at(tensor_name); 108 } GetGraphInputTensorShapeMapSize()109 size_t GetGraphInputTensorShapeMapSize() const { return graph_input_tensor_shape_map_.size(); } 110 SetGraphOutputTensorNames(const std::vector<std::string> & output_names)111 void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) { 112 graph_output_tensor_names_ = output_names; 113 } 114 GetGraphOutputTensorNames()115 const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; } 116 SetExternalUsedConfigInfos(const std::string & section,const std::map<std::string,std::string> & external_infos)117 void SetExternalUsedConfigInfos(const std::string §ion, 118 const std::map<std::string, std::string> &external_infos) { 119 for (auto const &external_info : external_infos) { 120 if (external_used_config_infos_[section].find(external_info.first) != 121 external_used_config_infos_[section].end()) { 122 MS_LOG(WARNING) << "This content " << external_info.first 123 << " has been saved. Now the value will be overwrite."; 124 } 125 external_used_config_infos_[section][external_info.first] = external_info.second; 126 } 127 } 128 GetExternalUsedConfigInfos()129 const std::map<std::string, std::map<std::string, std::string>> &GetExternalUsedConfigInfos() const { 130 return external_used_config_infos_; 131 } 132 SetTargetDevice(const std::string & target_device)133 void SetTargetDevice(const std::string &target_device) { target_device_ = target_device; } GetTargetDevice()134 std::string GetTargetDevice() const { return target_device_; } 135 Free()136 void Free() { 137 graph_input_data_type_map_.clear(); 138 graph_output_data_type_map_.clear(); 139 graph_input_tensor_shape_map_.clear(); 140 graph_output_tensor_names_.clear(); 141 external_used_config_infos_.clear(); 142 target_device_ = ""; 143 } 144 145 private: ConverterInnerContext()146 ConverterInnerContext() { 147 (void)external_used_config_infos_.emplace(mindspore::converter::KCommonQuantParam, 148 std::map<std::string, std::string>{}); 149 (void)external_used_config_infos_.emplace(mindspore::converter::KFullQuantParam, 150 std::map<std::string, std::string>{}); 151 (void)external_used_config_infos_.emplace(mindspore::converter::KDataPreProcess, 152 std::map<std::string, std::string>{}); 153 (void)external_used_config_infos_.emplace(mindspore::converter::KMixBitWeightQuantParam, 154 std::map<std::string, std::string>{}); 155 } 156 virtual ~ConverterInnerContext() = default; 157 std::map<int32_t, int32_t> graph_input_data_type_map_; 158 std::map<int32_t, int32_t> graph_output_data_type_map_; 159 std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_; 160 std::vector<std::string> graph_output_tensor_names_; 161 std::map<std::string, std::map<std::string, std::string>> external_used_config_infos_; 162 std::string target_device_; 163 }; 164 } // namespace lite 165 } // namespace mindspore 166 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_CONTEXT_H_ 167