1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 17 #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 18 #include <string> 19 #include <vector> 20 #include <memory> 21 #include <map> 22 #include <unordered_map> 23 #include <utility> 24 #include <set> 25 #include "schema/inner/model_generated.h" 26 #include "src/executor/kernel_exec.h" 27 #include "src/litert/lite_model.h" 28 #include "include/train/train_cfg.h" 29 30 namespace mindspore { 31 #ifndef _STUB 32 namespace schema { 33 struct CNodeT; 34 struct TensorT; 35 struct MetaGraphT; 36 } // namespace schema 37 #endif 38 namespace lite { 39 struct tensor_info { 40 size_t input_index; 41 OpParameter *op_parameter; 42 }; 43 44 class TrainExport { 45 public: TrainExport(const std::string file_name)46 explicit TrainExport(const std::string file_name) : file_name_(file_name) {} TrainExport(Buffer * model_buffer)47 explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {} 48 virtual ~TrainExport(); 49 int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 50 const std::vector<mindspore::lite::Tensor *> &tensors, 51 const std::vector<mindspore::lite::Tensor *> const_folded_output, 52 const std::vector<std::string> &output_names, const Model *model, QuantizationType quant_type, 53 const Model *bb_model = nullptr); 54 int ExportInit(const std::string model_name, std::string version); 55 int SaveToFile(); 56 int SaveToBuffer(); 57 int SaveWeightsToFile(bool enable_fp16 = false, const std::vector<std::string> &changeable_weights_name = {}); set_connect(const std::unordered_map<size_t,size_t> & map)58 void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } 59 int LoadModel(void *buf, size_t buf_size); 60 int AddTransformNode(); 61 int TrainModelFusion(); 62 int TrainModelDrop(); 63 int SaveModel(lite::Model *model, const std::string &file_name); 64 int SaveModel(lite::Model *model, Buffer *model_buffer); 65 66 protected: 67 virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor); 68 69 private: 70 Buffer *model_buffer_ = nullptr; 71 std::string file_name_; 72 schema::MetaGraphT *meta_graph_ = nullptr; 73 std::vector<size_t> out_idx_; 74 std::map<size_t, size_t> remap_; 75 std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id) 76 bool IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, const std::vector<size_t> &sinked_tensor_idxes); 77 int TopologicalSort(); 78 void PrepareRemap(int offset); 79 LiteGraph::Node *FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model); 80 std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, 81 const std::vector<mindspore::lite::Tensor *> const_folded_output, 82 schema::Tensor *scTensor, int preferred_dim, 83 const int tensor_quant_type); 84 std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::KernelExec *kernel, 85 std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex, 86 const Model *model); 87 bool IsInputTensor(const schema::TensorT &t); 88 int CreateAndAddCNode(const mindspore::kernel::KernelExec *kernel, std::vector<uint32_t> inputIndex, 89 std::vector<uint32_t> outputIndex, const Model *model); 90 std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex, 91 std::vector<uint32_t> outputIndex, size_t id); 92 std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id); 93 std::unique_ptr<schema::TensorT> CreateTransformConst(size_t last_id); 94 int AddTransform(); 95 bool NeedQuantization(const mindspore::lite::Tensor *tensor, const int tensor_quant_type); 96 int FindSchemaTensorByName(const std::vector<uint32_t> &search_indices, const std::string &search_name, 97 size_t *target_index); 98 int KeepGraphInputsInOrder(const Model *model); 99 int ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset, 100 const std::vector<mindspore::lite::Tensor *> const_folded_output, 101 const std::vector<std::pair<size_t, tensor_info>> &map_index, 102 const std::vector<std::string> &output_names, const std::set<size_t> &out_set); 103 virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor, 104 int preferred_dim); 105 mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::KernelExec *kernel); 106 void TagQuantizedNodes(); 107 QuantizationType quant_type_; 108 }; 109 }; // namespace lite 110 } // namespace mindspore 111 112 #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 113