/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_
#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include <utility>
#include <set>
#include "schema/inner/model_generated.h"
#include "src/executor/kernel_exec.h"
#include "src/litert/lite_model.h"
#include "include/train/train_cfg.h"

namespace mindspore {
#ifndef _STUB
namespace schema {
struct CNodeT;
struct TensorT;
struct MetaGraphT;
}  // namespace schema
#endif
namespace lite {
39 struct tensor_info {
40   size_t input_index;
41   OpParameter *op_parameter;
42 };
43 
44 class TrainExport {
45  public:
TrainExport(const std::string file_name)46   explicit TrainExport(const std::string file_name) : file_name_(file_name) {}
TrainExport(Buffer * model_buffer)47   explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {}
48   virtual ~TrainExport();
49   int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels,
50                 const std::vector<mindspore::lite::Tensor *> &tensors,
51                 const std::vector<mindspore::lite::Tensor *> const_folded_output,
52                 const std::vector<std::string> &output_names, const Model *model, QuantizationType quant_type,
53                 const Model *bb_model = nullptr);
54   int ExportInit(const std::string model_name, std::string version);
55   int SaveToFile();
56   int SaveToBuffer();
57   int SaveWeightsToFile(bool enable_fp16 = false, const std::vector<std::string> &changeable_weights_name = {});
set_connect(const std::unordered_map<size_t,size_t> & map)58   void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
59   int LoadModel(void *buf, size_t buf_size);
60   int AddTransformNode();
61   int TrainModelFusion();
62   int TrainModelDrop();
63   int SaveModel(lite::Model *model, const std::string &file_name);
64   int SaveModel(lite::Model *model, Buffer *model_buffer);
65 
66  protected:
67   virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor);
68 
69  private:
70   Buffer *model_buffer_ = nullptr;
71   std::string file_name_;
72   schema::MetaGraphT *meta_graph_ = nullptr;
73   std::vector<size_t> out_idx_;
74   std::map<size_t, size_t> remap_;
75   std::unordered_map<size_t, size_t> connect_;  // connection map (backbone tenor id-> head tensor id)
76   bool IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, const std::vector<size_t> &sinked_tensor_idxes);
77   int TopologicalSort();
78   void PrepareRemap(int offset);
79   LiteGraph::Node *FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model);
80   std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor,
81                                                 const std::vector<mindspore::lite::Tensor *> const_folded_output,
82                                                 schema::Tensor *scTensor, int preferred_dim,
83                                                 const int tensor_quant_type);
84   std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::KernelExec *kernel,
85                                               std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex,
86                                               const Model *model);
87   bool IsInputTensor(const schema::TensorT &t);
88   int CreateAndAddCNode(const mindspore::kernel::KernelExec *kernel, std::vector<uint32_t> inputIndex,
89                         std::vector<uint32_t> outputIndex, const Model *model);
90   std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex,
91                                                       std::vector<uint32_t> outputIndex, size_t id);
92   std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id);
93   std::unique_ptr<schema::TensorT> CreateTransformConst(size_t last_id);
94   int AddTransform();
95   bool NeedQuantization(const mindspore::lite::Tensor *tensor, const int tensor_quant_type);
96   int FindSchemaTensorByName(const std::vector<uint32_t> &search_indices, const std::string &search_name,
97                              size_t *target_index);
98   int KeepGraphInputsInOrder(const Model *model);
99   int ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset,
100                    const std::vector<mindspore::lite::Tensor *> const_folded_output,
101                    const std::vector<std::pair<size_t, tensor_info>> &map_index,
102                    const std::vector<std::string> &output_names, const std::set<size_t> &out_set);
103   virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor,
104                               int preferred_dim);
105   mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::KernelExec *kernel);
106   void TagQuantizedNodes();
107   QuantizationType quant_type_;
108 };
};  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_