1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 17 #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 18 #include <string> 19 #include <vector> 20 #include <memory> 21 #include <map> 22 #include <unordered_map> 23 #include "schema/inner/model_generated.h" 24 #include "src/lite_kernel.h" 25 #include "src/lite_model.h" 26 #include "include/train/train_cfg.h" 27 28 namespace mindspore { 29 #ifndef _STUB 30 namespace schema { 31 struct CNodeT; 32 struct TensorT; 33 struct MetaGraphT; 34 } // namespace schema 35 #endif 36 namespace lite { 37 38 class TrainExport { 39 public: TrainExport(const std::string file_name)40 explicit TrainExport(const std::string file_name) : file_name_(file_name) {} 41 virtual ~TrainExport(); 42 int ExportNet(const std::vector<mindspore::kernel::LiteKernel *> &kernels, 43 const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names, 44 const Model *model, QuantizationType quant_type); 45 int ExportInit(const std::string model_name, std::string version); 46 int SaveToFile(); set_connect(const std::unordered_map<size_t,size_t> & map)47 void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } 48 int LoadModel(void *buf, size_t buf_size); 49 int AddTransformNode(); 50 51 protected: 52 virtual std::vector<uint8_t> CreateData(const mindspore::lite::Tensor *tensor); 53 54 private: 55 std::string file_name_; 56 schema::MetaGraphT *meta_graph_ = nullptr; 57 std::vector<size_t> out_idx_; 58 std::map<size_t, size_t> remap_; 59 std::unordered_map<size_t, size_t> connect_; // connection map (backbone tenor id-> head tensor id) 60 bool IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node, const std::vector<size_t> &sinked_tensor_idxes); 61 int TopologicalSort(); 62 void PrepareRemap(int offset); 63 Model::Node *FindNode(const mindspore::kernel::LiteKernel *kernel, const Model *model); 64 std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor); 65 std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::LiteKernel *kernel, 66 std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex, 67 const Model *model); 68 int IsInputTensor(const schema::TensorT &t); 69 int CreateAndAddCNode(const mindspore::kernel::LiteKernel *kernel, std::vector<uint32_t> inputIndex, 70 std::vector<uint32_t> outputIndex, const Model *model); 71 std::unique_ptr<schema::CNodeT> CreateTransformNode(std::vector<uint32_t> inputIndex, 72 std::vector<uint32_t> outputIndex, size_t id); 73 std::unique_ptr<schema::TensorT> CreateTransformTensor(size_t id); 74 std::unique_ptr<schema::TensorT> CreateTransformConst(size_t last_id); 75 int AddTransform(); 76 bool NeedQuantization(const mindspore::lite::Tensor *tensor); 77 virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor); 78 mindspore::schema::QuantType GetNodeQuantType(const mindspore::kernel::LiteKernel *kernel); 79 void TagQuantizedNodes(); 80 QuantizationType quant_type_; 81 }; 82 }; // namespace lite 83 } // namespace mindspore 84 85 #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_EXPORT_H_ 86