• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/litert/cxx_api/model/model_impl.h"
18 #include <unordered_map>
19 #include <algorithm>
20 #include "include/api/serialization.h"
21 #include "include/api/callback/callback.h"
22 #include "include/api/metrics/metrics.h"
23 #include "src/litert/cxx_api/converters.h"
24 #include "src/litert/cxx_api/metrics/metrics_adapter.h"
25 #include "src/litert/cxx_api/metrics/metrics_impl.h"
26 #include "src/litert/cxx_api/callback/callback_adapter.h"
27 #include "src/litert/cxx_api/callback/callback_impl.h"
28 #include "src/common/log_adapter.h"
29 #include "src/train/train_session.h"
30 #include "src/train/transfer_session.h"
31 
32 namespace mindspore {
PrepareMetrics(Model * model,std::vector<session::Metrics * > * out_ms,std::vector<session::Metrics * > * adapter_ms)33 Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
34                                  std::vector<session::Metrics *> *adapter_ms) {
35   if (out_ms == nullptr || adapter_ms == nullptr) {
36     MS_LOG(ERROR) << "Null input callbacks";
37     return kLiteUninitializedObj;
38   }
39   auto model_metrics = GetMetrics();
40   for (auto m : model_metrics) {
41     if (m == nullptr) {
42       MS_LOG(ERROR) << "Null input metrics";
43       return kLiteUninitializedObj;
44     }
45     if (m->metrics_impl_ != nullptr) {
46       // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
47       auto internal_m = m->metrics_impl_->GetInternalMetrics();
48       if (internal_m == nullptr) {
49         MS_LOG(ERROR) << "Internal metric is null.";
50         clearVectorOfPointers(adapter_ms);
51         return kLiteUninitializedObj;
52       }
53       out_ms->push_back(internal_m);
54     } else {
55       // For custom metric we use the metric adapter to mediate between MSLite level to API level
56       auto adapter_m = new (std::nothrow) MetricsAdapter(m);
57       if (adapter_m == nullptr) {  // Error during allocation
58         MS_LOG(ERROR) << "Error during allocation";
59         clearVectorOfPointers(adapter_ms);
60         return kLiteNullptr;
61       }
62       out_ms->push_back(adapter_m);
63       adapter_ms->push_back(adapter_m);
64     }
65   }
66   return kSuccess;
67 }
68 
ConvertCallbacks(Model * model,std::vector<TrainCallBack * > * i_cbs,std::vector<lite::TrainLoopCallBack * > * o_cbs,std::vector<lite::TrainLoopCallBack * > * adapter_cbs)69 Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
70                                    std::vector<lite::TrainLoopCallBack *> *o_cbs,
71                                    std::vector<lite::TrainLoopCallBack *> *adapter_cbs) {
72   if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
73     MS_LOG(ERROR) << "Null input callbacks";
74     return kLiteUninitializedObj;
75   }
76   for (auto cb : *i_cbs) {
77     if (cb == nullptr) {
78       return kLiteUninitializedObj;
79     }
80     if (cb->callback_impl_ != nullptr) {
81       // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
82       auto internal_cb = cb->callback_impl_->GetInternalCallback();
83       if (internal_cb == nullptr) {
84         MS_LOG(ERROR) << "Internal callback is null";
85         clearVectorOfPointers(adapter_cbs);
86         return kLiteUninitializedObj;
87       }
88       o_cbs->push_back(internal_cb);
89     } else {
90       // For custom callbacks we use the callback adapter to mediate between MSLite level to API level
91       auto adapter_cb = new (std::nothrow) TrainLoopCallBackAdapter(model, cb);
92       if (adapter_cb == nullptr) {  // Error during allocation
93         MS_LOG(ERROR) << "Error during allocation";
94         clearVectorOfPointers(adapter_cbs);
95         return kLiteNullptr;
96       }
97       o_cbs->push_back(adapter_cb);
98       adapter_cbs->push_back(adapter_cb);
99     }
100   }
101   return kSuccess;
102 }
103 }  // namespace mindspore
104