• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/api/model.h"
18 #include "src/common/log_adapter.h"
19 #include "src/litert/cxx_api/model/model_impl.h"
20 namespace mindspore {
BuildTransferLearning(GraphCell backbone,GraphCell head,const std::shared_ptr<Context> & context,const std::shared_ptr<TrainCfg> & train_cfg)21 Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
22                                     const std::shared_ptr<TrainCfg> &train_cfg) {
23   std::stringstream err_msg;
24   if (impl_ == nullptr) {
25     impl_ = std::make_shared<ModelImpl>();
26     if (impl_ == nullptr) {
27       MS_LOG(ERROR) << "Model implement is null.";
28       return kLiteFileError;
29     }
30   }
31 
32   if (backbone.GetGraph() == nullptr || head.GetGraph() == nullptr) {
33     err_msg << "Invalid null graph.";
34     MS_LOG(ERROR) << err_msg.str();
35     return Status(kLiteNullptr, err_msg.str());
36   }
37   if (context == nullptr) {
38     err_msg << "Invalid null context.";
39     MS_LOG(ERROR) << err_msg.str();
40     return Status(kLiteNullptr, err_msg.str());
41   }
42   impl_->SetContext(context);
43   impl_->SetGraph(head.GetGraph());
44   impl_->SetConfig(train_cfg);
45 
46   Status ret = impl_->BuildTransferLearning(backbone.GetGraph(), head.GetGraph());
47   if (ret != kSuccess) {
48     return ret;
49   }
50   return kSuccess;
51 }
52 }  // namespace mindspore
53