/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "runtime/pynative/op_runtime_info.h"
18 
19 #include <utility>
20 
21 #include "include/backend/anf_runtime_algorithm.h"
22 #include "include/common/utils/anfalgo.h"
23 #include "runtime/device/ms_device_shape_transfer.h"
24 
25 namespace mindspore::runtime {
26 namespace {
OpRuntimeInfoGetOutputTensorMemSize(const AnfNodePtr & node,size_t output_index,TypeId type,const std::string & format,const ShapeVector & device_shape)27 size_t OpRuntimeInfoGetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index, TypeId type,
28                                            const std::string &format, const ShapeVector &device_shape) {
29   MS_EXCEPTION_IF_NULL(node);
30   if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
31     MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
32                                 << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
33   }
34   size_t type_size = GetTypeByte(TypeIdToType(type));
35   auto shape = device_shape;
36   if (IsDynamic(shape)) {
37     auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
38     if (!max_shape.empty()) {
39       shape = max_shape;
40       MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, using max_shape[" << max_shape << "] instead.";
41     } else {
42       shape = {1};
43       MS_LOG(DEBUG) << "shape[" << shape << "] is dynamic, set default to {1}";
44     }
45   }
46   if (shape.empty() && format != kOpFormat_DEFAULT) {
47     shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
48     shape = trans::TransShapeToDevice(shape, format, node, output_index, type);
49   }
50   // scalar's output shape is a empty vector
51   size_t tensor_size = type_size * SizeOf(shape);
52   return tensor_size;
53 }
54 
CacheForExecutionOrder(const KernelGraphPtr & graph)55 void CacheForExecutionOrder(const KernelGraphPtr &graph) {
56   MS_EXCEPTION_IF_NULL(graph);
57   const auto &nodes = graph->execution_order();
58   for (auto const &node : nodes) {
59     std::vector<std::string> formats;
60     std::vector<TypeId> types;
61     std::vector<size_t> tensor_sizes;
62     std::vector<ShapeVector> output_infer_shape;
63     std::vector<ShapeVector> output_device_shape;
64     auto output_num = AnfAlgo::GetOutputTensorNum(node);
65     for (size_t i = 0; i < output_num; ++i) {
66       std::string output_format = AnfAlgo::GetOutputFormat(node, i);
67       auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
68       auto device_shape = AnfAlgo::GetOutputDeviceShape(node, i);
69       auto tensor_size = OpRuntimeInfoGetOutputTensorMemSize(node, i, output_type, output_format, device_shape);
70       (void)formats.emplace_back(output_format);
71       (void)types.emplace_back(output_type);
72       (void)tensor_sizes.emplace_back(tensor_size);
73       (void)output_infer_shape.emplace_back(common::AnfAlgo::GetOutputInferShape(node, i));
74       (void)output_device_shape.emplace_back(device_shape);
75     }
76 
77     // For input
78     std::vector<std::pair<device::KernelInfo *, size_t>> input_kernel_infos;
79     auto input_size = common::AnfAlgo::GetInputTensorNum(node);
80     for (size_t i = 0; i < input_size; ++i) {
81       session::KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
82       MS_EXCEPTION_IF_NULL(kernel_with_index.first);
83       (void)input_kernel_infos.emplace_back(dynamic_cast<device::KernelInfo *>(kernel_with_index.first->kernel_info()),
84                                             kernel_with_index.second);
85     }
86 
87     // For workspace and output
88     MS_EXCEPTION_IF_NULL(node);
89     auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
90 
91     node->set_user_data<runtime::OpRuntimeInfo>(std::make_shared<runtime::OpRuntimeInfo>(
92       formats, types, tensor_sizes, output_infer_shape, output_device_shape, kernel_info, input_kernel_infos));
93   }
94 }
95 
CacheForGraphInputs(const KernelGraphPtr & graph)96 void CacheForGraphInputs(const KernelGraphPtr &graph) {
97   MS_EXCEPTION_IF_NULL(graph);
98   const auto &inputs = graph->inputs();
99   for (const auto &input : inputs) {
100     MS_EXCEPTION_IF_NULL(input);
101     if (!input->isa<Parameter>()) {
102       continue;
103     }
104     std::vector<std::string> formats;
105     std::vector<TypeId> types;
106     std::vector<size_t> tensor_sizes;
107     std::vector<ShapeVector> output_infer_shape;
108     std::vector<ShapeVector> output_device_shape;
109     auto output_size = AnfAlgo::GetOutputTensorNum(input);
110     for (size_t index = 0; index < output_size; index++) {
111       auto format = AnfAlgo::GetOutputFormat(input, index);
112       auto type_id = AnfAlgo::GetOutputDeviceDataType(input, index);
113       if (type_id == kTypeUnknown) {
114         type_id = common::AnfAlgo::GetOutputInferDataType(input, index);
115       }
116       auto device_shape = AnfAlgo::GetOutputDeviceShape(input, index);
117       auto tensor_size = OpRuntimeInfoGetOutputTensorMemSize(input, index, type_id, format, device_shape);
118       (void)formats.emplace_back(format);
119       (void)types.emplace_back(type_id);
120       (void)tensor_sizes.emplace_back(tensor_size);
121       (void)output_infer_shape.emplace_back(common::AnfAlgo::GetOutputInferShape(input, index));
122       (void)output_device_shape.emplace_back(device_shape);
123     }
124     input->set_user_data<runtime::OpRuntimeInfo>(
125       std::make_shared<runtime::OpRuntimeInfo>(formats, types, tensor_sizes, output_infer_shape, output_device_shape,
126                                                nullptr, std::vector<std::pair<device::KernelInfo *, size_t>>()));
127   }
128 }
129 }  // namespace
130 
output_format(size_t index) const131 std::string OpRuntimeInfo::output_format(size_t index) const {
132   if (index >= output_format_.size()) {
133     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_format:" << output_format_.size();
134   }
135   return output_format_[index];
136 }
137 
output_type(size_t index) const138 TypeId OpRuntimeInfo::output_type(size_t index) const {
139   if (index >= output_type_.size()) {
140     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_type:" << output_type_.size();
141   }
142   return output_type_[index];
143 }
144 
output_tensor_size(size_t index) const145 size_t OpRuntimeInfo::output_tensor_size(size_t index) const {
146   if (index >= output_tensor_size_.size()) {
147     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_tensor_size:" << output_tensor_size_.size();
148   }
149   return output_tensor_size_[index];
150 }
151 
output_infer_shape(size_t index) const152 const ShapeVector &OpRuntimeInfo::output_infer_shape(size_t index) const {
153   if (index >= output_infer_shape_.size()) {
154     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_infer_shape_.size();
155   }
156   return output_infer_shape_[index];
157 }
158 
output_device_shape(size_t index) const159 const ShapeVector &OpRuntimeInfo::output_device_shape(size_t index) const {
160   if (index >= output_device_shape_.size()) {
161     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_device_shape_.size();
162   }
163   return output_device_shape_[index];
164 }
165 
SetOutputTensorSize(size_t index,size_t tensor_size)166 void OpRuntimeInfo::SetOutputTensorSize(size_t index, size_t tensor_size) {
167   if (index >= output_tensor_size_.size()) {
168     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_tensor_size:" << output_tensor_size_.size();
169   }
170   output_tensor_size_[index] = tensor_size;
171 }
172 
SetOutputInferShape(size_t index,const ShapeVector & shape)173 void OpRuntimeInfo::SetOutputInferShape(size_t index, const ShapeVector &shape) {
174   if (index >= output_infer_shape_.size()) {
175     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_infer_shape_.size();
176   }
177   output_infer_shape_[index] = shape;
178 }
179 
SetOutputDeviceShape(size_t index,const ShapeVector & shape)180 void OpRuntimeInfo::SetOutputDeviceShape(size_t index, const ShapeVector &shape) {
181   if (index >= output_device_shape_.size()) {
182     MS_LOG(EXCEPTION) << "Invalid index:" << index << " total output_infer_shape:" << output_device_shape_.size();
183   }
184   output_device_shape_[index] = shape;
185 }
186 
GetOutputDeviceAddress(size_t index) const187 device::DeviceAddressPtr OpRuntimeInfo::GetOutputDeviceAddress(size_t index) const {
188   MS_EXCEPTION_IF_NULL(kernel_info_);
189   return kernel_info_->GetMutableOutputAddr(index);
190 }
191 
GetWorkspaceDeviceAddress(size_t index) const192 device::DeviceAddressPtr OpRuntimeInfo::GetWorkspaceDeviceAddress(size_t index) const {
193   MS_EXCEPTION_IF_NULL(kernel_info_);
194   return kernel_info_->GetMutableWorkspaceAddr(index);
195 }
196 
GetInputDeviceAddress(size_t index) const197 device::DeviceAddressPtr OpRuntimeInfo::GetInputDeviceAddress(size_t index) const {
198   if (index >= input_kernel_infos_.size()) {
199     MS_LOG(ERROR) << "Output range! index:" << index << " input size:" << input_kernel_infos_.size();
200     return nullptr;
201   }
202 
203   auto kernel_info_pair = input_kernel_infos_[index];
204   MS_EXCEPTION_IF_NULL(kernel_info_pair.first);
205   return kernel_info_pair.first->GetMutableOutputAddr(kernel_info_pair.second);
206 }
207 
GetInputSize() const208 size_t OpRuntimeInfo::GetInputSize() const { return input_kernel_infos_.size(); }
209 
GetOutputSize() const210 size_t OpRuntimeInfo::GetOutputSize() const {
211   MS_EXCEPTION_IF_NULL(kernel_info_);
212   return kernel_info_->output_address_list().size();
213 }
214 
GetWorkspaceSize() const215 size_t OpRuntimeInfo::GetWorkspaceSize() const {
216   MS_EXCEPTION_IF_NULL(kernel_info_);
217   return kernel_info_->workspace_address_list().size();
218 }
219 
GetKernelMod() const220 kernel::KernelMod *OpRuntimeInfo::GetKernelMod() const {
221   MS_EXCEPTION_IF_NULL(kernel_info_);
222   return kernel_info_->MutableKernelMod();
223 }
224 
Resize(const AnfNodePtr & node)225 void OpRuntimeInfo::Resize(const AnfNodePtr &node) {
226   auto output_num = AnfAlgo::GetOutputTensorNum(node);
227   for (size_t i = 0; i < output_num; ++i) {
228     auto device_shape = AnfAlgo::GetOutputDeviceShape(node, i);
229     SetOutputInferShape(i, common::AnfAlgo::GetOutputInferShape(node, i));
230     SetOutputDeviceShape(i, device_shape);
231     SetOutputTensorSize(i,
232                         OpRuntimeInfoGetOutputTensorMemSize(node, i, output_type(i), output_format(i), device_shape));
233   }
234 }
235 
// Entry point: populate OpRuntimeInfo user data for every kernel in the
// graph's execution order and for every Parameter graph input.
// `graph` is null-checked inside both helpers.
void OpRuntimeInfo::CacheGraphOpRuntimeInfo(const KernelGraphPtr &graph) {
  CacheForExecutionOrder(graph);
  CacheForGraphInputs(graph);
}
240 }  // namespace mindspore::runtime
241