1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/cxx_api/model/model_impl.h"
18 #include <memory>
19 #include <unordered_map>
20 #include <algorithm>
21 #include "include/api/types.h"
22 #include "include/api/context.h"
23 #include "include/api/dual_abi_helper.h"
24 #include "include/lite_session.h"
25 #include "include/context.h"
26 #include "include/api/callback/callback.h"
27 #include "include/api/metrics/metrics.h"
28 #include "src/lite_model.h"
29 #include "src/runtime/inner_allocator.h"
30 #include "src/cxx_api/converters.h"
31 #include "src/cxx_api/graph/graph_data.h"
32 #include "src/cxx_api/tensor/tensor_impl.h"
33 #include "src/cxx_api/tensor_utils.h"
34 #include "src/cxx_api/metrics/metrics_adapter.h"
35 #include "src/cxx_api/metrics/metrics_impl.h"
36 #include "src/cxx_api/callback/callback_adapter.h"
37 #include "src/cxx_api/callback/callback_impl.h"
38 #include "src/common/log_adapter.h"
39 #include "src/train/train_session.h"
40
41 namespace mindspore {
PrepareMetrics(Model * model,std::vector<session::Metrics * > * out_ms,std::vector<session::Metrics * > * adapter_ms)42 Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
43 std::vector<session::Metrics *> *adapter_ms) {
44 if (out_ms == nullptr || adapter_ms == nullptr) {
45 MS_LOG(ERROR) << "Null input callbacks";
46 return kLiteUninitializedObj;
47 }
48 auto model_metrics = GetMetrics();
49 for (auto m : model_metrics) {
50 if (m == nullptr) {
51 MS_LOG(ERROR) << "Null input metrics";
52 return kLiteUninitializedObj;
53 }
54 if (m->metrics_impl_) {
55 // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
56 auto internal_m = m->metrics_impl_->GetInternalMetrics();
57 if (internal_m == nullptr) {
58 MS_LOG(ERROR) << "Internal metric is null.";
59 clearVectorOfPointers(adapter_ms);
60 return kLiteUninitializedObj;
61 }
62 out_ms->push_back(internal_m);
63 } else {
64 // For custom metric we use the metric adapter to mediate between MSLite level to API level
65 auto adapter_m = new (std::nothrow) MetricsAdapter(m);
66 if (adapter_m == nullptr) { // Error during allocation
67 MS_LOG(ERROR) << "Error during allocation";
68 clearVectorOfPointers(adapter_ms);
69 return kLiteNullptr;
70 }
71 out_ms->push_back(adapter_m);
72 adapter_ms->push_back(adapter_m);
73 }
74 }
75 return kSuccess;
76 }
77
ConvertCallbacks(Model * model,std::vector<TrainCallBack * > * i_cbs,std::vector<session::TrainLoopCallBack * > * o_cbs,std::vector<session::TrainLoopCallBack * > * adapter_cbs)78 Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
79 std::vector<session::TrainLoopCallBack *> *o_cbs,
80 std::vector<session::TrainLoopCallBack *> *adapter_cbs) {
81 if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
82 MS_LOG(ERROR) << "Null input callbacks";
83 return kLiteUninitializedObj;
84 }
85 for (auto cb : *i_cbs) {
86 if (cb == nullptr) {
87 return kLiteUninitializedObj;
88 }
89 if (cb->callback_impl_) {
90 // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
91 auto internal_cb = cb->callback_impl_->GetInternalCallback();
92 if (internal_cb == nullptr) {
93 MS_LOG(ERROR) << "Internal callback is null";
94 clearVectorOfPointers(adapter_cbs);
95 return kLiteUninitializedObj;
96 }
97 o_cbs->push_back(internal_cb);
98 } else {
99 // For custom callbacks we use the callback adapter to mediate between MSLite level to API level
100 auto adapter_cb = new (std::nothrow) TrainLoopCallBackAdapter(model, cb);
101 if (adapter_cb == nullptr) { // Error during allocation
102 MS_LOG(ERROR) << "Error during allocation";
103 clearVectorOfPointers(adapter_cbs);
104 return kLiteNullptr;
105 }
106 o_cbs->push_back(adapter_cb);
107 adapter_cbs->push_back(adapter_cb);
108 }
109 }
110 return kSuccess;
111 }
112 } // namespace mindspore
113