/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_

#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <unordered_map>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/cell.h"
#include "src/litert/cxx_api/graph/graph_data.h"
#include "src/litert/inner_context.h"
#include "src/litert/lite_session.h"
#include "include/train/train_loop_callback.h"

template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
  if (v != nullptr) {
    for (typename std::vector<T>::iterator it = v->begin(); it != v->end(); ++it) {
      delete (*it);
    }
    v->clear();
  }
}
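// A minimal usage sketch (illustrative only, not part of this header): the helper
// deletes every element and then empties the vector, so it is intended for vectors
// that own raw pointers.
//
//   std::vector<int *> owned;
//   owned.push_back(new int(1));
//   clearVectorOfPointers(&owned);  // frees the element and leaves the vector empty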

namespace mindspore {

typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
                                                                    std::shared_ptr<TrainCfg> cfg,
                                                                    const std::shared_ptr<lite::InnerContext> &context);
MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);

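// Usage sketch (an assumption inferred from the default argument, not a documented
// contract): passing a non-null pointer registers the train-session factory, while
// calling the holder with the default nullptr returns the currently registered one.
// "MyCreateTrainSession", "graph_data", "cfg" and "inner_context" are placeholder names.
//
//   CreateTrainSessionCallbackHolder(MyCreateTrainSession);               // register
//   CreateTrainSessionProto *proto = CreateTrainSessionCallbackHolder();  // query
//   if (proto != nullptr) {
//     std::shared_ptr<lite::LiteSession> session = (*proto)(graph_data, cfg, inner_context);
//   }
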
namespace session {
class Metrics;
class TrainLoopCallBack;
}  // namespace session

class ModelImpl {
 public:
  ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
  ~ModelImpl() = default;

  Status Build();
  Status Build(const void *model_data, size_t data_size, ModelType model_type,
               const std::shared_ptr<Context> &model_context);
  Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
  Status UpdateWeights(const std::vector<MSTensor> &new_weights);

  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
                 const MSKernelCallBack &after);

  Status Predict(const MSKernelCallBack &before, const MSKernelCallBack &after);
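  // Callback usage sketch (illustrative only; it assumes the MSKernelCallBack and
  // MSCallBackParam definitions from include/api/types.h, and "impl", "model_inputs",
  // "model_outputs" are placeholder names). The before/after hooks are invoked around
  // each kernel, and returning false from a hook is expected to abort execution.
  //
  //   MSKernelCallBack before = [](const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
  //                                const MSCallBackParam &info) {
  //     // inspect info.node_name / info.node_type here
  //     return true;
  //   };
  //   impl->Predict(model_inputs, &model_outputs, before, nullptr);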

#if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug)
  // Note: the BuildAndRun interface is used for pre-build and pre-inference in a child process.
  Status BuildAndRun(const void *model_data, size_t data_size, ModelType model_type,
                     const std::shared_ptr<Context> &model_context);
  Status BuildAndRun(const std::string &model_path, ModelType model_type,
                     const std::shared_ptr<Context> &model_context);
  Status BuildAndRun();

  bool IsEnablePreInference();
#endif
  lite::LiteSession *CreateLiteSession(const std::shared_ptr<lite::InnerContext> &context);

  Status LoadConfig(const std::string &config_path);
  Status UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config);
  std::vector<MSTensor> GetInputs();
  std::vector<MSTensor> GetOutputs();
  std::vector<MSTensor> GetGradients() const;
  Status ApplyGradients(const std::vector<MSTensor> &gradients);
  std::vector<MSTensor> GetFeatureMaps() const;
  std::vector<MSTensor> GetTrainableParams() const;
  Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
  std::vector<MSTensor> GetOptimizerParams() const;
  Status SetOptimizerParams(const std::vector<MSTensor> &params);
  MSTensor GetInputByTensorName(const std::string &name);
  std::vector<std::string> GetOutputTensorNames();
  MSTensor GetOutputByTensorName(const std::string &name);
  std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
  Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
                               std::map<std::string, unsigned int> *outputGLTexture);

  static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
  bool IsTrainModel();
  Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
  Status SetLearningRate(float learning_rate);
  float GetLearningRate();
  Status BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head);

  Status InitMetrics(const std::vector<Metrics *> metrics) {
    metrics_ = metrics;
    return kSuccess;
  }
  std::vector<Metrics *> GetMetrics() { return metrics_; }
  const lite::LiteSession *GetSession() const { return session_.get(); }
  Status Finalize();

 protected:
  // Utility methods
  Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
                          std::vector<lite::TrainLoopCallBack *> *o_cbs,
                          std::vector<lite::TrainLoopCallBack *> *adapter_cbs);
  Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
                        std::vector<session::Metrics *> *adapter_ms);

 private:
  friend class Model;
  friend class Serialization;
  std::shared_ptr<Graph> graph_ = nullptr;
  std::shared_ptr<lite::LiteSession> session_ = nullptr;
  std::shared_ptr<Context> context_ = nullptr;
  std::shared_ptr<TrainCfg> cfg_ = nullptr;
  std::vector<Metrics *> metrics_;
  void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
  void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
  void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
  Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
  bool IsEnableModelSharing(const std::string &model_path);
  bool IsEnableModelSharing(const std::pair<const void *, size_t> &model_buff);
  bool IsValidDoubleNum(const std::string &num_str);
  int ModelDeObfuscate();
  std::map<std::string, TypeId> execution_plan_;
  std::map<std::string, std::map<std::string, std::string>> config_info_;
};
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_