1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ 17 #define MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ 18 #include <memory> 19 #include <vector> 20 #include <string> 21 #include <tuple> 22 #include <unordered_map> 23 #include <utility> 24 #include "src/litert/lite_session.h" 25 #include "src/train/train_session.h" 26 27 /* 28 Inheritance Diagram 29 30 +-------------------------------+ 31 | session::LiteSession | 32 +--------------↑----------------+ 33 | 34 +--------------+----------------+ 35 | lite::LiteSession | 36 +--------------↑----------------+ 37 | 38 +--------------+----------------+ 39 | lite::TrainSession | 40 +--------------↑----------------+ 41 | 42 +--------------+----------------+ 43 | lite::TrasferSession | 44 +-------------------------------+ 45 */ 46 47 namespace mindspore { 48 namespace lite { 49 50 class TransferSession : public lite::TrainSession { 51 public: 52 explicit TransferSession(const char *model_buf_backbone, size_t size_backbone, 53 const std::shared_ptr<lite::InnerContext> &context); 54 55 ~TransferSession(); 56 is_valid()57 bool is_valid() const { return is_valid_; } 58 59 int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; 60 61 void BindThread(bool if_bind) override; 62 std::vector<lite::Tensor *> GetInputs() const override; 63 mindspore::lite::Tensor *GetInputsByTensorName(const std::string &tensor_name) const override; 64 65 int CompileTransferGraph(); 66 int Export(const std::string &fb_name, ModelType model_type, QuantizationType quant_type, FormatType, 67 std::vector<std::string> out_put_tensor_name = {}) override; 68 int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType, 69 std::vector<std::string> out_put_tensor_name = {}) override; 70 71 protected: 72 LiteSession *backbone_session_ = nullptr; 73 char *lite_model_ = nullptr; 74 std::vector<mindspore::lite::Tensor *> combined_inputs_; 75 std::vector<std::pair<mindspore::lite::Tensor *, mindspore::lite::Tensor *>> backbone_head_map_; 76 bool is_valid_ = false; 77 78 private: 79 template <typename DestType> 80 int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType, 81 std::vector<std::string> out_put_tensor_name = {}); 82 bool CompileFormatTransform(lite::Tensor *out, lite::Tensor *in, int *mask, size_t mask_len); 83 std::unordered_map<size_t, size_t> ConnectionMap(); 84 bool nchw2nhwc_ = false; 85 size_t size_backbone_; 86 }; 87 88 lite::LiteSession *CreateTransferSessionInt(const char *model_buf_backbone, size_t size_backbone, 89 const char *model_buf_head, size_t size_head, 90 const std::shared_ptr<InnerContext> &context, bool train_mode, 91 const lite::TrainCfg *cfg); 92 } // namespace lite 93 } // namespace mindspore 94 #endif // MINDSPORE_LITE_SRC_TRAIN_TRANSFER_SESSION_H_ 95