1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "include/api/model.h"

#include <memory>
#include <string>

#include "src/common/log_adapter.h"
#include "src/litert/cxx_api/model/model_impl.h"
20 namespace mindspore {
BuildTransferLearning(GraphCell backbone,GraphCell head,const std::shared_ptr<Context> & context,const std::shared_ptr<TrainCfg> & train_cfg)21 Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context,
22 const std::shared_ptr<TrainCfg> &train_cfg) {
23 std::stringstream err_msg;
24 if (impl_ == nullptr) {
25 impl_ = std::make_shared<ModelImpl>();
26 if (impl_ == nullptr) {
27 MS_LOG(ERROR) << "Model implement is null.";
28 return kLiteFileError;
29 }
30 }
31
32 if (backbone.GetGraph() == nullptr || head.GetGraph() == nullptr) {
33 err_msg << "Invalid null graph.";
34 MS_LOG(ERROR) << err_msg.str();
35 return Status(kLiteNullptr, err_msg.str());
36 }
37 if (context == nullptr) {
38 err_msg << "Invalid null context.";
39 MS_LOG(ERROR) << err_msg.str();
40 return Status(kLiteNullptr, err_msg.str());
41 }
42 impl_->SetContext(context);
43 impl_->SetGraph(head.GetGraph());
44 impl_->SetConfig(train_cfg);
45
46 Status ret = impl_->BuildTransferLearning(backbone.GetGraph(), head.GetGraph());
47 if (ret != kSuccess) {
48 return ret;
49 }
50 return kSuccess;
51 }
52 } // namespace mindspore
53