1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/litert/cxx_api/model/model_impl.h"
18 #include <unordered_map>
19 #include <algorithm>
20 #include "include/api/serialization.h"
21 #include "include/api/callback/callback.h"
22 #include "include/api/metrics/metrics.h"
23 #include "src/litert/cxx_api/converters.h"
24 #include "src/litert/cxx_api/metrics/metrics_adapter.h"
25 #include "src/litert/cxx_api/metrics/metrics_impl.h"
26 #include "src/litert/cxx_api/callback/callback_adapter.h"
27 #include "src/litert/cxx_api/callback/callback_impl.h"
28 #include "src/common/log_adapter.h"
29 #include "src/train/train_session.h"
30 #include "src/train/transfer_session.h"
31
32 namespace mindspore {
PrepareMetrics(Model * model,std::vector<session::Metrics * > * out_ms,std::vector<session::Metrics * > * adapter_ms)33 Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
34 std::vector<session::Metrics *> *adapter_ms) {
35 if (out_ms == nullptr || adapter_ms == nullptr) {
36 MS_LOG(ERROR) << "Null input callbacks";
37 return kLiteUninitializedObj;
38 }
39 auto model_metrics = GetMetrics();
40 for (auto m : model_metrics) {
41 if (m == nullptr) {
42 MS_LOG(ERROR) << "Null input metrics";
43 return kLiteUninitializedObj;
44 }
45 if (m->metrics_impl_ != nullptr) {
46 // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
47 auto internal_m = m->metrics_impl_->GetInternalMetrics();
48 if (internal_m == nullptr) {
49 MS_LOG(ERROR) << "Internal metric is null.";
50 clearVectorOfPointers(adapter_ms);
51 return kLiteUninitializedObj;
52 }
53 out_ms->push_back(internal_m);
54 } else {
55 // For custom metric we use the metric adapter to mediate between MSLite level to API level
56 auto adapter_m = new (std::nothrow) MetricsAdapter(m);
57 if (adapter_m == nullptr) { // Error during allocation
58 MS_LOG(ERROR) << "Error during allocation";
59 clearVectorOfPointers(adapter_ms);
60 return kLiteNullptr;
61 }
62 out_ms->push_back(adapter_m);
63 adapter_ms->push_back(adapter_m);
64 }
65 }
66 return kSuccess;
67 }
68
ConvertCallbacks(Model * model,std::vector<TrainCallBack * > * i_cbs,std::vector<lite::TrainLoopCallBack * > * o_cbs,std::vector<lite::TrainLoopCallBack * > * adapter_cbs)69 Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
70 std::vector<lite::TrainLoopCallBack *> *o_cbs,
71 std::vector<lite::TrainLoopCallBack *> *adapter_cbs) {
72 if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
73 MS_LOG(ERROR) << "Null input callbacks";
74 return kLiteUninitializedObj;
75 }
76 for (auto cb : *i_cbs) {
77 if (cb == nullptr) {
78 return kLiteUninitializedObj;
79 }
80 if (cb->callback_impl_ != nullptr) {
81 // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
82 auto internal_cb = cb->callback_impl_->GetInternalCallback();
83 if (internal_cb == nullptr) {
84 MS_LOG(ERROR) << "Internal callback is null";
85 clearVectorOfPointers(adapter_cbs);
86 return kLiteUninitializedObj;
87 }
88 o_cbs->push_back(internal_cb);
89 } else {
90 // For custom callbacks we use the callback adapter to mediate between MSLite level to API level
91 auto adapter_cb = new (std::nothrow) TrainLoopCallBackAdapter(model, cb);
92 if (adapter_cb == nullptr) { // Error during allocation
93 MS_LOG(ERROR) << "Error during allocation";
94 clearVectorOfPointers(adapter_cbs);
95 return kLiteNullptr;
96 }
97 o_cbs->push_back(adapter_cb);
98 adapter_cbs->push_back(adapter_cb);
99 }
100 }
101 return kSuccess;
102 }
103 } // namespace mindspore
104