• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/kernel_info.h"
18 
19 namespace mindspore {
20 namespace device {
select_kernel_build_info() const21 const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); }
22 
GetMutableSelectKernelBuildInfo() const23 kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; }
24 
GetOutputAddr(size_t index) const25 const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const {
26   if (index >= output_address_list_.size()) {
27     MS_LOG(ERROR) << "Index [" << index << "] out of range 0~" << (output_address_list_.size() - 1);
28     return nullptr;
29   }
30   return output_address_list_[index].get();
31 }
32 
GetMutableOutputAddr(size_t index) const33 DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const {
34   if (index >= output_address_list_.size()) {
35     MS_LOG(ERROR) << "Index [" << index << "] out of range";
36     return nullptr;
37   }
38   return output_address_list_[index];
39 }
40 
OutputAddrExist(size_t index) const41 bool KernelInfo::OutputAddrExist(size_t index) const {
42   if (index >= output_address_list_.size()) {
43     return false;
44   }
45   return output_address_list_[index] != nullptr;
46 }
47 
SetOutputAddr(const DeviceAddressPtr & output_address,size_t index)48 bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) {
49   // parameter and valuenode
50   if (kernel_mod_ == nullptr && index >= output_address_list_.size()) {
51     for (size_t i = output_address_list_.size(); i <= index; i++) {
52       output_address_list_.emplace_back(nullptr);
53     }
54   } else if (kernel_mod_ != nullptr && output_address_list_.empty()) {
55     // set cnode
56     for (size_t i = 0; i < kernel_mod_->GetOutputSizeList().size(); i++) {
57       output_address_list_.emplace_back(nullptr);
58     }
59   }
60   if (index >= output_address_list_.size()) {
61     MS_LOG(ERROR) << "Index [" << index << "] out of range";
62     return false;
63   }
64   output_address_list_[index] = output_address;
65   return true;
66 }
67 
GetWorkspaceAddr(size_t index) const68 DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const {
69   if (index >= workspace_address_list_.size()) {
70     MS_LOG(ERROR) << "Index [" << index << "] out of range";
71     return nullptr;
72   }
73   return workspace_address_list_[index].get();
74 }
75 
GetMutableWorkspaceAddr(size_t index) const76 DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
77   if (index >= workspace_address_list_.size()) {
78     MS_LOG(ERROR) << "Index [" << index << "] out of range";
79     return nullptr;
80   }
81   return workspace_address_list_[index];
82 }
83 
WorkspaceAddrExist(size_t index) const84 bool KernelInfo::WorkspaceAddrExist(size_t index) const {
85   if (index >= workspace_address_list_.size()) {
86     return false;
87   }
88   return workspace_address_list_[index] != nullptr;
89 }
90 
SetWorkspaceAddr(const DeviceAddressPtr & output_address,size_t index)91 bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
92   if (workspace_address_list_.empty()) {
93     // parameter and valuenode
94     if (kernel_mod_ == nullptr) {
95       workspace_address_list_.emplace_back(nullptr);
96     } else {
97       // set cnode
98       for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) {
99         workspace_address_list_.emplace_back(nullptr);
100       }
101     }
102   }
103   if (index >= workspace_address_list_.size()) {
104     MS_LOG(ERROR) << "Index [" << index << "] out of range";
105     return false;
106   }
107   workspace_address_list_[index] = output_address;
108   return true;
109 }
110 
set_kernel_mod(const kernel::KernelModPtr & kernel_mod)111 void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; }
112 
MutableKernelMod() const113 kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); }
114 
kernel_mod() const115 const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); }
116 
operator ==(const KernelInfo & other) const117 bool KernelInfo::operator==(const KernelInfo &other) const {
118   if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ ||
119       graph_id_ != other.graph_id_) {
120     return false;
121   }
122   if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) ||
123       (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) {
124     return false;
125   }
126   if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) {
127     if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) {
128       return false;
129     }
130   }
131   // Currently we only check whether both the kernel_mod_ are initialized or uninitialized.
132   if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) ||
133       (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) {
134     return false;
135   }
136   // Currently we only check whether both the sizes are equal of output_address_list_ and workspace_address_list_ or
137   // not. We can complete this check in the future.
138   if (output_address_list_.size() != other.output_address_list_.size() ||
139       workspace_address_list_.size() != other.workspace_address_list_.size()) {
140     return false;
141   }
142   return true;
143 }
144 }  // namespace device
145 }  // namespace mindspore
146