1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
18 #define MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
19
20 #include <functional>
21 #include <map>
22 #include <string>
23 #include <vector>
24 #include <memory>
25 #include <utility>
26 #include <unordered_map>
27 #include "include/api/model.h"
28 #include "include/api/context.h"
29 #include "include/api/cell.h"
30 #include "include/lite_session.h"
31 #include "src/cxx_api/graph/graph_data.h"
32 #include "src/inner_context.h"
33 #include "src/lite_session.h"
34
/// Deletes every element of `*v` with scalar `delete` and then empties the vector.
///
/// @param v  Vector whose elements are pointers allocated with `new` (null
///           elements are allowed, since deleting null is a no-op). Passing a
///           null `v` does nothing.
template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
  if (v == nullptr) {
    return;
  }
  for (auto &owned : *v) {
    delete owned;
  }
  // Drop the now-dangling pointers so the vector cannot be used to reach freed memory.
  v->clear();
}
44
45 namespace mindspore {
46
47 typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
48 std::shared_ptr<TrainCfg> cfg,
49 lite::InnerContext *context);
50 CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
51
// Forward declarations of the session-level training types. This header only
// refers to them through pointers (see ConvertCallbacks / PrepareMetrics), so
// the full definitions are not needed here.
namespace session {
class TrainLoopCallBack;
class Metrics;
}  // namespace session
56
57 class ModelImpl {
58 public:
ModelImpl()59 ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
60 ~ModelImpl() = default;
61
62 Status Build();
63 Status Build(const void *model_data, size_t data_size, ModelType model_type,
64 const std::shared_ptr<Context> &model_context);
65 Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
66 Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
67
68 Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
69 const MSKernelCallBack &after);
70
71 lite::LiteSession *CreateLiteSession(lite::InnerContext *context);
72
73 Status LoadConfig(const std::string &config_path);
74 std::vector<MSTensor> GetInputs();
75 std::vector<MSTensor> GetOutputs();
76 std::vector<MSTensor> GetGradients() const;
77 Status ApplyGradients(const std::vector<MSTensor> &gradients);
78 std::vector<MSTensor> GetOptimizerParams() const;
79 Status SetOptimizerParams(const std::vector<MSTensor> ¶ms);
80 MSTensor GetInputByTensorName(const std::string &name);
81 std::vector<std::string> GetOutputTensorNames();
82 MSTensor GetOutputByTensorName(const std::string &name);
83 std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
84
85 static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
86 bool IsTrainModel();
87
InitMetrics(std::vector<Metrics * > metrics)88 Status InitMetrics(std::vector<Metrics *> metrics) {
89 metrics_ = metrics;
90 return kSuccess;
91 }
GetMetrics()92 std::vector<Metrics *> GetMetrics() { return metrics_; }
GetSession()93 const session::LiteSession *GetSession() const { return session_.get(); }
94
95 protected:
96 // Utility methods
97 Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
98 std::vector<session::TrainLoopCallBack *> *o_cbs,
99 std::vector<session::TrainLoopCallBack *> *adapter_cbs);
100 Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
101 std::vector<session::Metrics *> *adapter_ms);
102
103 private:
104 friend class Model;
105 friend class Serialization;
106 std::shared_ptr<Graph> graph_ = nullptr;
107 std::shared_ptr<lite::LiteSession> session_ = nullptr;
108 std::shared_ptr<Context> context_ = nullptr;
109 std::shared_ptr<TrainCfg> cfg_ = nullptr;
110 std::vector<Metrics *> metrics_;
SetGraph(const std::shared_ptr<Graph> & graph)111 void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
SetContext(const std::shared_ptr<Context> & context)112 void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
SetConfig(const std::shared_ptr<TrainCfg> cfg)113 void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
114 Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
115 std::map<std::string, TypeId> execution_plan_;
116 };
117 } // namespace mindspore
118
119 #endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
120