• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_
18 #define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_
19 
#include <functional>
#include <memory>
#include <utility>

#include "extendrt/mindir_loader/abstract_base_model.h"
23 
24 namespace mindspore::infer {
25 class ModelLoader {
26  public:
27   virtual AbstractBaseModel *ImportModel(const char *model_buf, size_t size, bool take_buf) = 0;
28 
29  protected:
30   virtual int InitModelBuffer(AbstractBaseModel *model, const char *model_buf, size_t size, bool take_buf);
31 };
32 
33 class ModelLoaderRegistry {
34  public:
35   ModelLoaderRegistry();
36   virtual ~ModelLoaderRegistry();
37 
38   static ModelLoaderRegistry *GetInstance();
39 
RegModelLoader(mindspore::ModelType model_type,std::function<std::shared_ptr<ModelLoader> ()> creator)40   void RegModelLoader(mindspore::ModelType model_type, std::function<std::shared_ptr<ModelLoader>()> creator) {
41     model_loader_map_[model_type] = creator;
42   }
43 
GetModelLoader(mindspore::ModelType model_type)44   std::shared_ptr<ModelLoader> GetModelLoader(mindspore::ModelType model_type) {
45     auto it = model_loader_map_.find(model_type);
46     if (it == model_loader_map_.end()) {
47       return nullptr;
48     }
49     return it->second();
50   }
51 
52  private:
53   mindspore::HashMap<mindspore::ModelType, std::function<std::shared_ptr<ModelLoader>()>> model_loader_map_;
54 };
55 
56 class ModelLoaderRegistrar {
57  public:
ModelLoaderRegistrar(const mindspore::ModelType & model_type,std::function<std::shared_ptr<ModelLoader> ()> creator)58   ModelLoaderRegistrar(const mindspore::ModelType &model_type, std::function<std::shared_ptr<ModelLoader>()> creator) {
59     ModelLoaderRegistry::GetInstance()->RegModelLoader(model_type, creator);
60   }
61   ~ModelLoaderRegistrar() = default;
62 };
63 
// Defines a static ModelLoaderRegistrar named g_<model_type>model_loaderModelLoader whose constructor
// registers model_loader_creator as the loader factory for model_type.
// model_type is token-pasted into the variable name, so it must be a single identifier
// (a qualified name such as ns::kType cannot be pasted).
// The expansion already ends with ';'.
#define REG_MODEL_LOADER(model_type, model_loader_creator) \
  static ModelLoaderRegistrar g_##model_type##model_loader##ModelLoader(model_type, model_loader_creator);
66 }  // namespace mindspore::infer
67 
68 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MODEL_LOADER_H_
69