1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_
19
20 #include <functional>
21 #include <map>
22 #include <string>
23 #include <vector>
24 #include <memory>
25 #include <utility>
26 #include <unordered_map>
27 #include "include/api/model.h"
28 #include "include/api/context.h"
29 #include "include/api/cell.h"
30 #include "src/litert/cxx_api/graph/graph_data.h"
31 #include "src/litert/inner_context.h"
32 #include "src/litert/lite_session.h"
33 #include "include/train/train_loop_callback.h"
34
/// Deletes every element of a vector that holds owning raw pointers, then empties the vector.
/// @param v  Vector whose elements are passed to `delete`; a null `v` is ignored.
/// The elements are not reset to nullptr before `clear()`, so the vector must not be used
/// to reach them afterwards (it is emptied anyway).
template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
  if (v == nullptr) {
    return;
  }
  for (auto &owned : *v) {
    delete owned;
  }
  v->clear();
}
44
45 namespace mindspore {
46
47 typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
48 std::shared_ptr<TrainCfg> cfg,
49 const std::shared_ptr<lite::InnerContext> &context);
50 MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
51
// Forward declarations of types in the nested `session` namespace. Only the names are needed
// here: session::Metrics appears as a pointer type in ModelImpl::PrepareMetrics.
// NOTE(review): session::TrainLoopCallBack is not referenced in this header (ConvertCallbacks
// uses lite::TrainLoopCallBack); it may be kept for includers, so confirm before removing.
namespace session {
class TrainLoopCallBack;
class Metrics;
}  // namespace session
56
57 class ModelImpl {
58 public:
ModelImpl()59 ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
60 ~ModelImpl() = default;
61
62 Status Build();
63 Status Build(const void *model_data, size_t data_size, ModelType model_type,
64 const std::shared_ptr<Context> &model_context);
65 Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
66 Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
67 Status UpdateWeights(const std::vector<MSTensor> &new_weights);
68
69 Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
70 const MSKernelCallBack &after);
71
72 Status Predict(const MSKernelCallBack &before, const MSKernelCallBack &after);
73
74 #if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug)
75 // note that: BuildAndRun interface is used for pre-build and pre-inferenence in child process.
76 Status BuildAndRun(const void *model_data, size_t data_size, ModelType model_type,
77 const std::shared_ptr<Context> &model_context);
78 Status BuildAndRun(const std::string &model_path, ModelType model_type,
79 const std::shared_ptr<Context> &model_context);
80 Status BuildAndRun();
81
82 bool IsEnablePreInference();
83 #endif
84 lite::LiteSession *CreateLiteSession(const std::shared_ptr<lite::InnerContext> &context);
85
86 Status LoadConfig(const std::string &config_path);
87 Status UpdateConfig(const std::string §ion, const std::pair<std::string, std::string> &config);
88 std::vector<MSTensor> GetInputs();
89 std::vector<MSTensor> GetOutputs();
90 std::vector<MSTensor> GetGradients() const;
91 Status ApplyGradients(const std::vector<MSTensor> &gradients);
92 std::vector<MSTensor> GetFeatureMaps() const;
93 std::vector<MSTensor> GetTrainableParams() const;
94 Status UpdateFeatureMaps(const std::vector<MSTensor> &new_weights);
95 std::vector<MSTensor> GetOptimizerParams() const;
96 Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
97 MSTensor GetInputByTensorName(const std::string &name);
98 std::vector<std::string> GetOutputTensorNames();
99 MSTensor GetOutputByTensorName(const std::string &name);
100 std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
101 Status BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
102 std::map<std::string, unsigned int> *outputGLTexture);
103
104 static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
105 bool IsTrainModel();
106 Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum);
107 Status SetLearningRate(float learning_rate);
108 float GetLearningRate();
109 Status BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head);
110
InitMetrics(const std::vector<Metrics * > metrics)111 Status InitMetrics(const std::vector<Metrics *> metrics) {
112 metrics_ = metrics;
113 return kSuccess;
114 }
GetMetrics()115 std::vector<Metrics *> GetMetrics() { return metrics_; }
GetSession()116 const lite::LiteSession *GetSession() const { return session_.get(); }
117 Status Finalize();
118
119 protected:
120 // Utility methods
121 Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
122 std::vector<lite::TrainLoopCallBack *> *o_cbs,
123 std::vector<lite::TrainLoopCallBack *> *adapter_cbs);
124 Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
125 std::vector<session::Metrics *> *adapter_ms);
126
127 private:
128 friend class Model;
129 friend class Serialization;
130 std::shared_ptr<Graph> graph_ = nullptr;
131 std::shared_ptr<lite::LiteSession> session_ = nullptr;
132 std::shared_ptr<Context> context_ = nullptr;
133 std::shared_ptr<TrainCfg> cfg_ = nullptr;
134 std::vector<Metrics *> metrics_;
SetGraph(const std::shared_ptr<Graph> & graph)135 void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
SetContext(const std::shared_ptr<Context> & context)136 void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
SetConfig(const std::shared_ptr<TrainCfg> cfg)137 void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
138 Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
139 bool IsEnableModelSharing(const std::string &model_path);
140 bool IsEnableModelSharing(const std::pair<const void *, size_t> &model_buff);
141 bool IsValidDoubleNum(const std::string &num_str);
142 int ModelDeObfuscate();
143 std::map<std::string, TypeId> execution_plan_;
144 std::map<std::string, std::map<std::string, std::string>> config_info_;
145 };
146 } // namespace mindspore
147
148 #endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_MODEL_IMPL_H_
149