• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
18 #define MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
19 
20 #include <functional>
21 #include <map>
22 #include <string>
23 #include <vector>
24 #include <memory>
25 #include <utility>
26 #include <unordered_map>
27 #include "include/api/model.h"
28 #include "include/api/context.h"
29 #include "include/api/cell.h"
30 #include "include/lite_session.h"
31 #include "src/cxx_api/graph/graph_data.h"
32 #include "src/inner_context.h"
33 #include "src/lite_session.h"
34 
/// Deletes every element stored in `*v`, then empties the vector.
///
/// @tparam T  Element type; each element is passed to `delete`, so it must be a
///            pointer obtained from a non-array `new` (or null).
/// @param v   Vector to drain. A null `v` is accepted and treated as a no-op.
template <class T>
void clearVectorOfPointers(std::vector<T> *v) {
  if (v == nullptr) {
    return;
  }
  for (auto &element : *v) {
    delete element;
  }
  // Drop the now-dangling pointers so nobody can reach them through the vector.
  v->clear();
}
44 
45 namespace mindspore {
46 
47 typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ptr<Graph::GraphData> graph_data,
48                                                                     std::shared_ptr<TrainCfg> cfg,
49                                                                     lite::InnerContext *context);
50 CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr);
51 
// Forward declarations of the session-layer types that this header only mentions
// through pointers (see ConvertCallbacks / PrepareMetrics in ModelImpl).
namespace session {
class Metrics;
class TrainLoopCallBack;
}  // namespace session
56 
57 class ModelImpl {
58  public:
ModelImpl()59   ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
60   ~ModelImpl() = default;
61 
62   Status Build();
63   Status Build(const void *model_data, size_t data_size, ModelType model_type,
64                const std::shared_ptr<Context> &model_context);
65   Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
66   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
67 
68   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
69                  const MSKernelCallBack &after);
70 
71   lite::LiteSession *CreateLiteSession(lite::InnerContext *context);
72 
73   Status LoadConfig(const std::string &config_path);
74   std::vector<MSTensor> GetInputs();
75   std::vector<MSTensor> GetOutputs();
76   std::vector<MSTensor> GetGradients() const;
77   Status ApplyGradients(const std::vector<MSTensor> &gradients);
78   std::vector<MSTensor> GetOptimizerParams() const;
79   Status SetOptimizerParams(const std::vector<MSTensor> &params);
80   MSTensor GetInputByTensorName(const std::string &name);
81   std::vector<std::string> GetOutputTensorNames();
82   MSTensor GetOutputByTensorName(const std::string &name);
83   std::vector<MSTensor> GetOutputsByNodeName(const std::string &name);
84 
85   static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
86   bool IsTrainModel();
87 
InitMetrics(std::vector<Metrics * > metrics)88   Status InitMetrics(std::vector<Metrics *> metrics) {
89     metrics_ = metrics;
90     return kSuccess;
91   }
GetMetrics()92   std::vector<Metrics *> GetMetrics() { return metrics_; }
GetSession()93   const session::LiteSession *GetSession() const { return session_.get(); }
94 
95  protected:
96   // Utility methods
97   Status ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
98                           std::vector<session::TrainLoopCallBack *> *o_cbs,
99                           std::vector<session::TrainLoopCallBack *> *adapter_cbs);
100   Status PrepareMetrics(Model *model, std::vector<session::Metrics *> *o_ms,
101                         std::vector<session::Metrics *> *adapter_ms);
102 
103  private:
104   friend class Model;
105   friend class Serialization;
106   std::shared_ptr<Graph> graph_ = nullptr;
107   std::shared_ptr<lite::LiteSession> session_ = nullptr;
108   std::shared_ptr<Context> context_ = nullptr;
109   std::shared_ptr<TrainCfg> cfg_ = nullptr;
110   std::vector<Metrics *> metrics_;
SetGraph(const std::shared_ptr<Graph> & graph)111   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
SetContext(const std::shared_ptr<Context> & context)112   void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
SetConfig(const std::shared_ptr<TrainCfg> cfg)113   void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
114   Status RunGraph(const MSKernelCallBack &before, const MSKernelCallBack &after);
115   std::map<std::string, TypeId> execution_plan_;
116 };
117 }  // namespace mindspore
118 
119 #endif  // MINDSPORE_LITE_SRC_CXX_API_MODEL_MODEL_IMPL_H_
120