/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_runtime_info.h"

#include <utility>

#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "runtime/device/ms_device_shape_transfer.h"

namespace mindspore::runtime {
namespace {
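// Computes the device memory size (in bytes) of the given output of `node`: the byte size of
// `type` multiplied by the element count of `device_shape`. Dynamic shapes fall back to the
// output's max shape (or {1} when no max shape is known), and an empty shape in a non-default
// format is padded and transformed to the device shape first.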
size_t OpRuntimeInfoGetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index, TypeId type,
                                           const std::string &format, const ShapeVector &device_shape) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] is larger than the output size ["
                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
  }
  size_t type_size = GetTypeByte(TypeIdToType(type));
  auto shape = device_shape;
  if (IsDynamic(shape)) {
    auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
    if (!max_shape.empty()) {
      MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, using max_shape[" << max_shape << "] instead.";
      shape = max_shape;
    } else {
      MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, set default to {1}.";
      shape = {1};
    }
  }
  if (shape.empty() && format != kOpFormat_DEFAULT) {
    shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
    shape = trans::TransShapeToDevice(shape, format, node, output_index, type);
  }
  // A scalar's output shape is an empty vector.
  size_t tensor_size = type_size * SizeOf(shape);
  return tensor_size;
}

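// Caches an OpRuntimeInfo on every kernel in the graph's execution order, recording the per-output
// format, device data type, tensor size and shapes, the kernel's own KernelInfo, and the
// KernelInfo/output-index pairs of the nodes that produce its inputs.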
void CacheForExecutionOrder(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto &nodes = graph->execution_order();
  for (auto const &node : nodes) {
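    // For output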
    std::vector<std::string> formats;
    std::vector<TypeId> types;
    std::vector<size_t> tensor_sizes;
    std::vector<ShapeVector> output_infer_shape;
    std::vector<ShapeVector> output_device_shape;
    auto output_num = AnfAlgo::GetOutputTensorNum(node);
    for (size_t i = 0; i < output_num; ++i) {
      std::string output_format = AnfAlgo::GetOutputFormat(node, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
      auto device_shape = AnfAlgo::GetOutputDeviceShape(node, i);
      auto tensor_size = OpRuntimeInfoGetOutputTensorMemSize(node, i, output_type, output_format, device_shape);
      (void)formats.emplace_back(output_format);
      (void)types.emplace_back(output_type);
      (void)tensor_sizes.emplace_back(tensor_size);
      (void)output_infer_shape.emplace_back(common::AnfAlgo::GetOutputInferShape(node, i));
      (void)output_device_shape.emplace_back(device_shape);
    }

    // For input
    std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos;
    auto input_size = common::AnfAlgo::GetInputTensorNum(node);
    for (size_t i = 0; i < input_size; ++i) {
      session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
      MS_EXCEPTION_IF_NULL(kernel_with_index.first);
      (void)input_kernel_infos.emplace_back(dynamic_cast<device::KernelInfo *>(kernel_with_index.first->kernel_info()),
                                            kernel_with_index.second);
    }

    // For workspace and output
    MS_EXCEPTION_IF_NULL(node);
    auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());

    node->set_user_data<runtime::OpRuntimeInfo>(std::make_shared<runtime::OpRuntimeInfo>(
      formats, types, tensor_sizes, output_infer_shape, output_device_shape, kernel_info, input_kernel_infos));
  }
}

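// Caches an OpRuntimeInfo on every Parameter input of the graph. Graph inputs have no kernel of
// their own, so no KernelInfo or input kernel infos are recorded; when the device data type is
// unknown, the inferred data type is used instead.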
void CacheForGraphInputs(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto &inputs = graph->inputs();
  for (const auto &input : inputs) {
    MS_EXCEPTION_IF_NULL(input);
    if (!input->isa<Parameter>()) {
      continue;
    }
    std::vector<std::string> formats;
    std::vector<TypeId> types;
    std::vector<size_t> tensor_sizes;
    std::vector<ShapeVector> output_infer_shape;
    std::vector<ShapeVector> output_device_shape;
    auto output_size = AnfAlgo::GetOutputTensorNum(input);
    for (size_t index = 0; index < output_size; index++) {
      auto format = AnfAlgo::GetOutputFormat(input, index);
      auto type_id = AnfAlgo::GetOutputDeviceDataType(input, index);
      if (type_id == kTypeUnknown) {
        type_id = common::AnfAlgo::GetOutputInferDataType(input, index);
      }
      auto device_shape = AnfAlgo::GetOutputDeviceShape(input, index);
      auto tensor_size = OpRuntimeInfoGetOutputTensorMemSize(input, index, type_id, format, device_shape);
      (void)formats.emplace_back(format);
      (void)types.emplace_back(type_id);
      (void)tensor_sizes.emplace_back(tensor_size);
      (void)output_infer_shape.emplace_back(common::AnfAlgo::GetOutputInferShape(input, index));
      (void)output_device_shape.emplace_back(device_shape);
    }
    input->set_user_data<runtime::OpRuntimeInfo>(
      std::make_shared<runtime::OpRuntimeInfo>(formats, types, tensor_sizes, output_infer_shape, output_device_shape,
                                               nullptr, std::vector<std::pair<device::KernelInfo *, size_t>>()));
  }
}
}  // namespace

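// Bounds-checked accessors for the cached per-output metadata.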
std::string OpRuntimeInfo::output_format(size_t index) const {
  if (index >= output_format_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
  }
  return output_format_[index];
}

TypeId OpRuntimeInfo::output_type(size_t index) const {
  if (index >= output_type_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
  }
  return output_type_[index];
}

size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
  if (index >= output_tensor_size_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_tensor_size:" << output_tensor_size_.size();
  }
  return output_tensor_size_[index];
}

const ShapeVector &OpRuntimeInfo::output_infer_shape(size_t index) const {
  if (index >= output_infer_shape_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_infer_shape_.size();
  }
  return output_infer_shape_[index];
}

const ShapeVector &OpRuntimeInfo::output_device_shape(size_t index) const {
  if (index >= output_device_shape_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_device_shape:" << output_device_shape_.size();
  }
  return output_device_shape_[index];
}

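// Setters used by Resize() to refresh the cached shapes and tensor sizes of dynamic-shape kernels.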
void OpRuntimeInfo::SetOutputTensorSize(size_t index, size_t tensor_size) {
  if (index >= output_tensor_size_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_tensor_size:" << output_tensor_size_.size();
  }
  output_tensor_size_[index] = tensor_size;
}

void OpRuntimeInfo::SetOutputInferShape(size_t index, const ShapeVector &shape) {
  if (index >= output_infer_shape_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_infer_shape_.size();
  }
  output_infer_shape_[index] = shape;
}

void OpRuntimeInfo::SetOutputDeviceShape(size_t index, const ShapeVector &shape) {
  if (index >= output_device_shape_.size()) {
    MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_device_shape:" << output_device_shape_.size();
  }
  output_device_shape_[index] = shape;
}

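// Device addresses and the kernel mod are looked up from the cached KernelInfo pointers; an
// input's address is the producing kernel's output address at the recorded output index.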
device::DeviceAddressPtr OpRuntimeInfo::GetOutputDeviceAddress(size_t index) const {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  return kernel_info_->GetMutableOutputAddr(index);
}

device::DeviceAddressPtr OpRuntimeInfo::GetWorkspaceDeviceAddress(size_t index) const {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  return kernel_info_->GetMutableWorkspaceAddr(index);
}

device::DeviceAddressPtr OpRuntimeInfo::GetInputDeviceAddress(size_t index) const {
  if (index >= input_kernel_infos_.size()) {
    MS_LOG(ERROR) << "Index out of range! index:" << index << " input size:" << input_kernel_infos_.size();
    return nullptr;
  }

  auto kernel_info_pair = input_kernel_infos_[index];
  MS_EXCEPTION_IF_NULL(kernel_info_pair.first);
  return kernel_info_pair.first->GetMutableOutputAddr(kernel_info_pair.second);
}

size_t OpRuntimeInfo::GetInputSize() const { return input_kernel_infos_.size(); }

size_t OpRuntimeInfo::GetOutputSize() const {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  return kernel_info_->output_address_list().size();
}

size_t OpRuntimeInfo::GetWorkspaceSize() const {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  return kernel_info_->workspace_address_list().size();
}

kernel::KernelMod *OpRuntimeInfo::GetKernelMod() const {
  MS_EXCEPTION_IF_NULL(kernel_info_);
  return kernel_info_->MutableKernelMod();
}

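// Recomputes the cached inferred shape, device shape, and tensor size of every output from the
// node's current shape information (e.g. after a dynamic-shape kernel has been resized).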
void OpRuntimeInfo::Resize(const AnfNodePtr &node) {
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    auto device_shape = AnfAlgo::GetOutputDeviceShape(node, i);
    SetOutputInferShape(i, common::AnfAlgo::GetOutputInferShape(node, i));
    SetOutputDeviceShape(i, device_shape);
    SetOutputTensorSize(i,
                        OpRuntimeInfoGetOutputTensorMemSize(node, i, output_type(i), output_format(i), device_shape));
  }
}

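// Builds the OpRuntimeInfo cache for a whole kernel graph: first for every kernel in execution
// order, then for the graph's Parameter inputs. A rough usage sketch (assuming a prepared
// KernelGraphPtr `kernel_graph`, and reading the cache back through the user_data counterpart of
// set_user_data):
//
//   OpRuntimeInfo::CacheGraphOpRuntimeInfo(kernel_graph);
//   for (const auto &kernel : kernel_graph->execution_order()) {
//     auto op_runtime_info = kernel->user_data<runtime::OpRuntimeInfo>();
//     MS_EXCEPTION_IF_NULL(op_runtime_info);
//     auto size = op_runtime_info->output_tensor_size(0);  // cached byte size of output 0
//   }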
void OpRuntimeInfo::CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph) {
  CacheForExecutionOrder(graph);
  CacheForGraphInputs(graph);
}
}  // namespace mindspore::runtime