1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/backend/kernel_info.h"
18
19 namespace mindspore {
20 namespace device {
select_kernel_build_info() const21 const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); }
22
GetMutableSelectKernelBuildInfo() const23 kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; }
24
GetOutputKernelTensor(size_t index) const25 const KernelTensorPtr &KernelInfo::GetOutputKernelTensor(size_t index) const {
26 if (index >= output_kernel_tensor_list_.size()) {
27 MS_LOG(EXCEPTION) << "Index [" << index << "] out of range 0~" << (output_kernel_tensor_list_.size() - 1);
28 }
29 return output_kernel_tensor_list_[index];
30 }
31
SetOutputKernelTensor(const KernelTensorPtr & kernel_tensor,size_t index)32 bool KernelInfo::SetOutputKernelTensor(const KernelTensorPtr &kernel_tensor, size_t index) {
33 // Initialize empty output kernel tensor list for Parameter and ValueNode.
34 if (kernel_mod_ == nullptr && index >= output_kernel_tensor_list_.size()) {
35 for (size_t i = output_kernel_tensor_list_.size(); i <= index; i++) {
36 (void)output_kernel_tensor_list_.emplace_back(nullptr);
37 }
38 } else if (kernel_mod_ != nullptr && output_kernel_tensor_list_.empty()) {
39 // Initialize empty output kernel tensor list for CNode.
40 MS_EXCEPTION_IF_NULL(select_kernel_build_info_);
41 for (size_t i = 0; i < select_kernel_build_info_->GetOutputNum(); i++) {
42 (void)output_kernel_tensor_list_.emplace_back(nullptr);
43 }
44 }
45
46 if (index >= output_kernel_tensor_list_.size()) {
47 MS_LOG(ERROR) << "Index [" << index << "] out of range";
48 return false;
49 }
50
51 output_kernel_tensor_list_[index] = kernel_tensor;
52 return true;
53 }
54
OutputKernelTensorExist(size_t index) const55 bool KernelInfo::OutputKernelTensorExist(size_t index) const {
56 if (index >= output_kernel_tensor_list_.size()) {
57 return false;
58 }
59 return output_kernel_tensor_list_[index] != nullptr;
60 }
61
GetOutputAddr(size_t index) const62 const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const {
63 if (index >= output_address_list_.size()) {
64 MS_LOG(ERROR) << "Index [" << index << "] out of range 0~" << (output_address_list_.size() - 1);
65 return nullptr;
66 }
67 return output_address_list_[index].get();
68 }
69
GetMutableOutputAddr(size_t index) const70 DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const {
71 if (index >= output_address_list_.size()) {
72 MS_LOG(ERROR) << "Index [" << index << "] out of range";
73 return nullptr;
74 }
75 return output_address_list_[index];
76 }
77
OutputAddrExist(size_t index) const78 bool KernelInfo::OutputAddrExist(size_t index) const {
79 if (index >= output_address_list_.size()) {
80 return false;
81 }
82 return output_address_list_[index] != nullptr;
83 }
84
SetOutputAddr(const DeviceAddressPtr & output_address,size_t index)85 bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) {
86 // parameter and valuenode
87 if (kernel_mod_ == nullptr && index >= output_address_list_.size()) {
88 for (size_t i = output_address_list_.size(); i <= index; i++) {
89 (void)output_address_list_.emplace_back(nullptr);
90 }
91 } else if (kernel_mod_ != nullptr && output_address_list_.empty()) {
92 // set cnode
93 MS_EXCEPTION_IF_NULL(select_kernel_build_info_);
94 for (size_t i = 0; i < select_kernel_build_info_->GetOutputNum(); i++) {
95 (void)output_address_list_.emplace_back(nullptr);
96 }
97 }
98 if (index >= output_address_list_.size()) {
99 MS_LOG(ERROR) << "Index [" << index << "] out of range";
100 return false;
101 }
102 output_address_list_[index] = output_address;
103
104 return true;
105 }
106
GetWorkspaceAddr(size_t index) const107 DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const {
108 if (index >= workspace_address_list_.size()) {
109 MS_LOG(ERROR) << "Index [" << index << "] out of range";
110 return nullptr;
111 }
112 return workspace_address_list_[index].get();
113 }
114
GetMutableWorkspaceAddr(size_t index) const115 DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
116 if (index >= workspace_address_list_.size()) {
117 MS_LOG(ERROR) << "Index [" << index << "] out of range";
118 return nullptr;
119 }
120 return workspace_address_list_[index];
121 }
122
WorkspaceAddrExist(size_t index) const123 bool KernelInfo::WorkspaceAddrExist(size_t index) const {
124 if (index >= workspace_address_list_.size()) {
125 return false;
126 }
127 return workspace_address_list_[index] != nullptr;
128 }
129
SetWorkspaceAddr(const DeviceAddressPtr & output_address,size_t index)130 bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
131 if (workspace_address_list_.empty()) {
132 // parameter and valuenode
133 if (kernel_mod_ == nullptr) {
134 (void)workspace_address_list_.emplace_back(nullptr);
135 } else {
136 // set cnode
137 for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) {
138 (void)workspace_address_list_.emplace_back(nullptr);
139 }
140 }
141 }
142 if (index >= workspace_address_list_.size()) {
143 MS_LOG(ERROR) << "Index [" << index << "] out of range";
144 return false;
145 }
146 workspace_address_list_[index] = output_address;
147 return true;
148 }
149
SetSomasResult(std::vector<std::pair<size_t,size_t>> && output_somas_result,std::vector<std::pair<size_t,size_t>> && workspace_somas_result)150 bool KernelInfo::SetSomasResult(std::vector<std::pair<size_t, size_t>> &&output_somas_result,
151 std::vector<std::pair<size_t, size_t>> &&workspace_somas_result) {
152 somas_output_result_ = std::move(output_somas_result);
153 somas_workspace_result_ = std::move(workspace_somas_result);
154 return true;
155 }
156
GetTensorSomasOffset(const std::vector<std::pair<size_t,size_t>> & somas_result,size_t tensor_index) const157 size_t KernelInfo::GetTensorSomasOffset(const std::vector<std::pair<size_t, size_t>> &somas_result,
158 size_t tensor_index) const {
159 if (somas_result.empty()) {
160 return 0;
161 }
162 if (tensor_index >= somas_result.size()) {
163 MS_LOG(EXCEPTION) << "The tensor index:" << tensor_index << " is out of range:" << somas_result.size();
164 }
165 return somas_result[tensor_index].first;
166 }
167
GetTensorSomasAlignedSize(const std::vector<std::pair<size_t,size_t>> & somas_result,size_t tensor_index) const168 size_t KernelInfo::GetTensorSomasAlignedSize(const std::vector<std::pair<size_t, size_t>> &somas_result,
169 size_t tensor_index) const {
170 if (somas_result.empty()) {
171 return 0;
172 }
173 if (tensor_index >= somas_result.size()) {
174 MS_LOG(EXCEPTION) << "The tensor index:" << tensor_index << " is out of range:" << somas_result.size();
175 }
176 return somas_result[tensor_index].second;
177 }
178
IsTensorEnableSomas(const std::vector<std::pair<size_t,size_t>> & somas_result,size_t tensor_index) const179 bool KernelInfo::IsTensorEnableSomas(const std::vector<std::pair<size_t, size_t>> &somas_result,
180 size_t tensor_index) const {
181 if (somas_result.empty()) {
182 return false;
183 }
184 if (tensor_index >= somas_result.size()) {
185 MS_LOG(EXCEPTION) << "The tensor index:" << tensor_index << " is out of range:" << somas_result.size();
186 }
187 return (somas_result[tensor_index].second != 0);
188 }
189
set_kernel_mod(const kernel::KernelModPtr & kernel_mod)190 void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; }
191
MutableKernelMod() const192 kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); }
193
GetKernelMod() const194 kernel::KernelModPtr KernelInfo::GetKernelMod() const { return kernel_mod_; }
195
kernel_mod() const196 const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); }
197
operator ==(const KernelInfo & other) const198 bool KernelInfo::operator==(const KernelInfo &other) const {
199 if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ ||
200 graph_id_ != other.graph_id_) {
201 return false;
202 }
203 if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) ||
204 (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) {
205 return false;
206 }
207 if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) {
208 if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) {
209 return false;
210 }
211 }
212 // Currently we only check whether both the kernel_mod_ are initialized or uninitialized.
213 if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) ||
214 (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) {
215 return false;
216 }
217 // Currently we only check whether both the sizes are equal of output_address_list_ and workspace_address_list_ or
218 // not. We can complete this check in the future.
219 if (output_address_list_.size() != other.output_address_list_.size() ||
220 workspace_address_list_.size() != other.workspace_address_list_.size()) {
221 return false;
222 }
223 return true;
224 }
225
set_ref_map(const bool & all_ref,const OutputInputRefMap & ref_map)226 void KernelInfo::set_ref_map(const bool &all_ref, const OutputInputRefMap &ref_map) {
227 if (all_ref) {
228 MS_EXCEPTION_IF_NULL(select_kernel_build_info_);
229 for (size_t i = 0; i < select_kernel_build_info_->GetInputNum(); i++) {
230 out_in_ref_map_[i] = i;
231 }
232 } else {
233 out_in_ref_map_ = ref_map;
234 }
235 }
236 } // namespace device
237 } // namespace mindspore
238