• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/cxx_api/model/model_impl.h"
18 #include <memory>
19 #include <unordered_map>
20 #include <algorithm>
21 #include "include/api/types.h"
22 #include "include/api/context.h"
23 #include "include/api/dual_abi_helper.h"
24 #include "include/lite_session.h"
25 #include "include/context.h"
26 #include "include/api/callback/callback.h"
27 #include "include/api/metrics/metrics.h"
28 #include "src/lite_model.h"
29 #include "src/runtime/inner_allocator.h"
30 #include "src/cxx_api/converters.h"
31 #include "src/cxx_api/graph/graph_data.h"
32 #include "src/cxx_api/tensor/tensor_impl.h"
33 #include "src/cxx_api/tensor_utils.h"
34 #include "src/cxx_api/metrics/metrics_adapter.h"
35 #include "src/cxx_api/metrics/metrics_impl.h"
36 #include "src/cxx_api/callback/callback_adapter.h"
37 #include "src/cxx_api/callback/callback_impl.h"
38 #include "src/common/log_adapter.h"
39 #include "src/train/train_session.h"
40 
41 namespace mindspore {
PrepareMetrics(Model * model,std::vector<session::Metrics * > * out_ms,std::vector<session::Metrics * > * adapter_ms)42 Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *out_ms,
43                                  std::vector<session::Metrics *> *adapter_ms) {
44   if (out_ms == nullptr || adapter_ms == nullptr) {
45     MS_LOG(ERROR) << "Null input callbacks";
46     return kLiteUninitializedObj;
47   }
48   auto model_metrics = GetMetrics();
49   for (auto m : model_metrics) {
50     if (m == nullptr) {
51       MS_LOG(ERROR) << "Null input metrics";
52       return kLiteUninitializedObj;
53     }
54     if (m->metrics_impl_) {
55       // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
56       auto internal_m = m->metrics_impl_->GetInternalMetrics();
57       if (internal_m == nullptr) {
58         MS_LOG(ERROR) << "Internal metric is null.";
59         clearVectorOfPointers(adapter_ms);
60         return kLiteUninitializedObj;
61       }
62       out_ms->push_back(internal_m);
63     } else {
64       // For custom metric we use the metric adapter to mediate between MSLite level to API level
65       auto adapter_m = new (std::nothrow) MetricsAdapter(m);
66       if (adapter_m == nullptr) {  // Error during allocation
67         MS_LOG(ERROR) << "Error during allocation";
68         clearVectorOfPointers(adapter_ms);
69         return kLiteNullptr;
70       }
71       out_ms->push_back(adapter_m);
72       adapter_ms->push_back(adapter_m);
73     }
74   }
75   return kSuccess;
76 }
77 
ConvertCallbacks(Model * model,std::vector<TrainCallBack * > * i_cbs,std::vector<session::TrainLoopCallBack * > * o_cbs,std::vector<session::TrainLoopCallBack * > * adapter_cbs)78 Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs,
79                                    std::vector<session::TrainLoopCallBack *> *o_cbs,
80                                    std::vector<session::TrainLoopCallBack *> *adapter_cbs) {
81   if (i_cbs == nullptr || o_cbs == nullptr || adapter_cbs == nullptr) {
82     MS_LOG(ERROR) << "Null input callbacks";
83     return kLiteUninitializedObj;
84   }
85   for (auto cb : *i_cbs) {
86     if (cb == nullptr) {
87       return kLiteUninitializedObj;
88     }
89     if (cb->callback_impl_) {
90       // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
91       auto internal_cb = cb->callback_impl_->GetInternalCallback();
92       if (internal_cb == nullptr) {
93         MS_LOG(ERROR) << "Internal callback is null";
94         clearVectorOfPointers(adapter_cbs);
95         return kLiteUninitializedObj;
96       }
97       o_cbs->push_back(internal_cb);
98     } else {
99       // For custom callbacks we use the callback adapter to mediate between MSLite level to API level
100       auto adapter_cb = new (std::nothrow) TrainLoopCallBackAdapter(model, cb);
101       if (adapter_cb == nullptr) {  // Error during allocation
102         MS_LOG(ERROR) << "Error during allocation";
103         clearVectorOfPointers(adapter_cbs);
104         return kLiteNullptr;
105       }
106       o_cbs->push_back(adapter_cb);
107       adapter_cbs->push_back(adapter_cb);
108     }
109   }
110   return kSuccess;
111 }
112 }  // namespace mindspore
113