1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_ 18 #define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_ 19 20 #include <memory> 21 22 #include "extendrt/mindir_loader/abstract_base_model.h" 23 24 namespace mindspore::infer { 25 class ModelLoader { 26 public: 27 virtual AbstractBaseModel *ImportModel(const char *model_buf, size_t size, bool take_buf) = 0; 28 29 protected: 30 virtual int InitModelBuffer(AbstractBaseModel *model, const char *model_buf, size_t size, bool take_buf); 31 }; 32 33 class ModelLoaderRegistry { 34 public: 35 ModelLoaderRegistry(); 36 virtual ~ModelLoaderRegistry(); 37 38 static ModelLoaderRegistry *GetInstance(); 39 RegModelLoader(mindspore::ModelType model_type,std::function<std::shared_ptr<ModelLoader> ()> creator)40 void RegModelLoader(mindspore::ModelType model_type, std::function<std::shared_ptr<ModelLoader>()> creator) { 41 model_loader_map_[model_type] = creator; 42 } 43 GetModelLoader(mindspore::ModelType model_type)44 std::shared_ptr<ModelLoader> GetModelLoader(mindspore::ModelType model_type) { 45 auto it = model_loader_map_.find(model_type); 46 if (it == model_loader_map_.end()) { 47 return nullptr; 48 } 49 return it->second(); 50 } 51 52 private: 53 mindspore::HashMap<mindspore::ModelType, std::function<std::shared_ptr<ModelLoader>()>> model_loader_map_; 54 }; 55 56 class ModelLoaderRegistrar { 57 public: ModelLoaderRegistrar(const mindspore::ModelType & model_type,std::function<std::shared_ptr<ModelLoader> ()> creator)58 ModelLoaderRegistrar(const mindspore::ModelType &model_type, std::function<std::shared_ptr<ModelLoader>()> creator) { 59 ModelLoaderRegistry::GetInstance()->RegModelLoader(model_type, creator); 60 } 61 ~ModelLoaderRegistrar() = default; 62 }; 63 64 #define REG_MODEL_LOADER(model_type, model_loader_creator) \ 65 static ModelLoaderRegistrar g_##model_type##model_loader##ModelLoader(model_type, model_loader_creator); 66 } // namespace mindspore::infer 67 68 #endif // MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_ 69