/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"

#include <memory>
#include <algorithm>
#include <stack>
#include <set>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include "mindspore/core/ops/framework_ops.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "utils/anf_utils.h"
#include "kernel/framework_utils.h"
#include "ops/op_def.h"
#include "utils/ms_context.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "include/common/profiler.h"
#include "ir/anf.h"
#include "ir/functor.h"
#include "backend/operator/ops_backend_infer_function.h"

namespace mindspore {
namespace opt::dynamic_shape {
namespace {
constexpr int64_t kInvalidShape = -2;

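// Re-infer shapes along a chain of nop nodes: walk back from input_node to the first non-nop producer, then
// infer each nop node on the way down so that a skipped nop node still carries an up-to-date shape.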
void InferShapeForNopNode(const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (!common::AnfAlgo::IsNopNode(input_node)) {
    MS_LOG(INFO) << "Input node is not a nop node, no need to infer.";
    return;
  }
  if (!common::AnfAlgo::IsNeedSkipNopOpExecution(input_node)) {
    MS_LOG(INFO) << "The nop node needs execution, so InferShapeForNopNode is not required.";
    return;
  }
  MS_LOG(INFO) << "Infer shape for nop node.";
  std::stack<AnfNodePtr> nop_road;
  nop_road.push(input_node);

  auto in_node = input_node;
  while (true) {
    auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(in_node, 0);
    in_node = input_node_with_idx.first;
    MS_EXCEPTION_IF_NULL(in_node);
    if (common::AnfAlgo::IsNopNode(in_node)) {
      nop_road.push(in_node);
    } else {
      break;
    }
  }

  while (!nop_road.empty()) {
    auto nop_node = nop_road.top();
    MS_EXCEPTION_IF_NULL(nop_node);
    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
    nop_road.pop();
  }
}

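// Return the common element type of a sequence abstract. Every element must be a scalar or a tensor, and all
// elements must share one type; otherwise an exception is thrown.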
TypeId GetSequenceType(const abstract::AbstractSequencePtr &seq_abs) {
  MS_EXCEPTION_IF_NULL(seq_abs);
  auto elems = seq_abs->elements();
  MS_EXCEPTION_IF_CHECK_FAIL(elems.size() >= 1, "Element size is less than 1.");
  MS_EXCEPTION_IF_NULL(elems[0]);
  if (!elems[0]->isa<abstract::AbstractScalar>() && !elems[0]->isa<abstract::AbstractTensor>()) {
    MS_LOG(EXCEPTION) << "The 0'th element of the sequence must be a scalar or a tensor, but got:"
                      << seq_abs->ToString();
  }

  auto fixed_type = (elems[0]->isa<abstract::AbstractScalar>()
                       ? elems[0]->BuildType()->type_id()
                       : elems[0]->cast<abstract::AbstractTensorPtr>()->element()->BuildType()->type_id());
  for (size_t i = 1; i < elems.size(); i++) {
    MS_EXCEPTION_IF_NULL(elems[i]);
    if (!elems[i]->isa<abstract::AbstractScalar>() && !elems[i]->isa<abstract::AbstractTensor>()) {
      MS_LOG(EXCEPTION) << "The " << i << "'th element of the sequence must be a scalar or a tensor, but got:"
                        << elems[i]->ToString();
    }
    MS_EXCEPTION_IF_NULL(elems[i]->BuildType());
    auto follow_type = (elems[i]->isa<abstract::AbstractScalar>()
                          ? elems[i]->BuildType()->type_id()
                          : elems[i]->cast<abstract::AbstractTensorPtr>()->element()->BuildType()->type_id());
    if (fixed_type != follow_type) {
      MS_LOG(EXCEPTION) << "Different type found between 0'th element[Type: " << fixed_type << "] and " << i
                        << "'th element[Type: " << follow_type << "]";
    }
  }
  return fixed_type;
}

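// Create an empty host tensor whose type and shape are derived from the abstract (and, when available, the
// kernel info) of the indexed output of the input node.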
tensor::TensorPtr CreateTensorFromIndexedNode(const std::pair<AnfNodePtr, size_t> &input_node_with_index) {
  auto real_input = input_node_with_index.first;
  MS_EXCEPTION_IF_NULL(real_input);
  auto real_input_index = input_node_with_index.second;
  auto abs = real_input->abstract();
  MS_EXCEPTION_IF_NULL(abs);

  ShapeVector shape;
  TypeId type;
  if (abs->isa<abstract::AbstractScalar>()) {
    shape = {1};
    MS_EXCEPTION_IF_NULL(abs->BuildType());
    type = abs->BuildType()->type_id();
  } else if (AnfAlgo::IsRealSquenceOutput(real_input)) {
    auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abs);
    auto elem_num = seq_abs->size();
    if (elem_num == 0) {
      MS_LOG(DEBUG) << "Empty sequence for node:" << real_input->fullname_with_scope();
      return std::make_shared<tensor::Tensor>(TypeId::kNumberTypeInt64, ShapeVector({0}));
    }
    type = GetSequenceType(seq_abs);
    shape = {SizeToLong(elem_num)};
  } else if (abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSequence>()) {
    shape = trans::GetRuntimePaddingShape(real_input, real_input_index);
    if (real_input->isa<ValueNode>()) {
      // The type of a ValueNode in KernelInfo is kTypeUnknown, so use the inferred type instead.
      type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
    } else {
      type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
      if (type == TypeId::kTypeUnknown) {
        type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
      }
    }
  } else {
    MS_LOG(EXCEPTION) << "For node:" << real_input->fullname_with_scope() << ", abstract(" << abs->ToString()
                      << ") is invalid.";
  }

  MS_LOG(DEBUG) << "Create tensor by node:" << input_node_with_index.first->DebugString()
                << " index:" << input_node_with_index.second << " type:" << type << " shape:" << shape
                << " abstract:" << abs->ToString();
  return std::make_shared<tensor::Tensor>(type, shape);
}

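// Allocate the host tensor that will hold a depend value. For inputs of a PyExecute node the type and shape are
// taken from the device address passed in args; otherwise they come from the input node itself.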
tensor::TensorPtr CreateTensorMem(const std::pair<AnfNodePtr, size_t> &input_node_with_index, const AnfNodePtr &node,
                                  size_t i, void *args) {
  if (node != nullptr && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPyExecute)) {
    MS_EXCEPTION_IF_NULL(args);
    auto input_list = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
    MS_EXCEPTION_IF_NULL(input_list);
    if (i >= input_list->size() || input_list->at(i) == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get device address by input num:" << i << " for node:" << node->DebugString();
    }
    const auto &device_address = input_list->at(i);
    MS_EXCEPTION_IF_NULL(device_address->kernel_tensor());
    MS_LOG(DEBUG) << "input node:" << input_node_with_index.first->DebugString()
                  << " abstract:" << input_node_with_index.first->abstract()->ToString()
                  << " device address:" << device_address << " type id:" << device_address->kernel_tensor()->dtype_id()
                  << " shape vector:" << device_address->kernel_tensor()->GetShapeVector();
    auto type_id = device_address->kernel_tensor()->dtype_id();
    if (device_address->kernel_tensor()->GetType() != nullptr &&
        ((device_address->kernel_tensor()->GetType()->isa<Tuple>() &&
          device_address->kernel_tensor()->GetType()->cast<TuplePtr>()->size() == 0) ||
         (device_address->kernel_tensor()->GetType()->isa<List>() &&
          device_address->kernel_tensor()->GetType()->cast<ListPtr>()->size() == 0))) {
      type_id = TypeId::kNumberTypeInt64;
    }
    return std::make_shared<tensor::Tensor>(type_id, device_address->kernel_tensor()->GetShapeVector());
  }

  return CreateTensorFromIndexedNode(input_node_with_index);
}

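// Fetch the value of a value-depend input as a host tensor: prefer the device address passed in args and sync
// its data, fall back to the output address recorded on the input node, and throw if neither holds valid data.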
tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
                                       const std::pair<AnfNodePtr, size_t> &input_node_with_index, bool skip_nop_node,
                                       void *args) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input_node_with_index.first);
  if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && input_node_with_index.first->isa<ValueNode>()) {
    const auto &value_node = input_node_with_index.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::Tensor>()) {
      return value->cast<tensor::TensorPtr>();
    } else if (value->isa<Scalar>()) {
      return ScalarToTensor(value->cast<ScalarPtr>());
    }
  }
  auto depended_value = CreateTensorMem(input_node_with_index, node, i, args);
  MS_EXCEPTION_IF_NULL(depended_value);
  // First, try the device addresses passed in args.
  if (args != nullptr) {
    auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
    MS_EXCEPTION_IF_NULL(input_device_address);
    if (i < input_device_address->size() && input_device_address->at(i) != nullptr) {
      uint64_t start_time = 0;
      PROFILER_START(start_time);
      auto addr = reinterpret_cast<device::DeviceAddress *>(input_device_address->at(i));
      MS_EXCEPTION_IF_NULL(addr);
      auto node_idx = addr->node_index();
      auto user_data = addr->user_data();
      if (user_data != nullptr && user_data->has(kernel::PyExecuteOutputUserData::key)) {
        auto addr_node = node_idx.first.lock();
        MS_EXCEPTION_IF_NULL(addr_node);
        auto out_addr = AnfAlgo::GetMutableOutputAddr(addr_node, node_idx.second, skip_nop_node);
        depended_value->set_device_address(out_addr, false);
        return depended_value;
      }
      MS_LOG(DEBUG) << "Get depend value tensor for node:" << node->DebugString() << " input index:" << i
                    << " input node:" << input_node_with_index.first->DebugString()
                    << " index:" << input_node_with_index.second << " node addr:" << input_node_with_index.first
                    << " device_address:" << input_device_address->at(i)
                    << " type id:" << input_device_address->at(i)->type_id();
      depended_value->data_sync_directly(input_device_address->at(i));
      PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferDataSync,
                   node->fullname_with_scope(), true);
      return depended_value;
    }
    MS_LOG(WARNING) << "There is no valid data for the " << i << "'th input of " << node->DebugString() << ", "
                    << node->fullname_with_scope();
  }

  // Second, fall back to the device address recorded on the input node.
  auto output_addr =
    AnfAlgo::GetMutableOutputAddr(input_node_with_index.first, input_node_with_index.second, skip_nop_node);
  MS_EXCEPTION_IF_NULL(output_addr);
  if (output_addr != nullptr && output_addr->IsPtrValid()) {
    // The second parameter must be false, otherwise the device address cannot be released and reallocated, and the
    // address size would be wrong in the dynamic shape scenario.
    depended_value->set_device_address(output_addr, false);
    uint64_t start_time = 0;
    PROFILER_START(start_time);
    // PyExecute uses the data in user_data instead of the address, so there is no need to sync data from device.
    if (IsPrimitiveCNode(input_node_with_index.first, prim::kPrimPyExecute)) {
      MS_LOG(DEBUG) << "The input node is " << input_node_with_index.first->ToString()
                    << ", use user data instead of address.";
      return depended_value;
    }
    MS_LOG(DEBUG) << "Get depend value tensor for node:" << node->DebugString() << " input index:" << i
                  << " input node:" << input_node_with_index.first->DebugString()
                  << " index:" << input_node_with_index.second << " node addr:" << input_node_with_index.first
                  << " sync for device tensor:" << output_addr;
    depended_value->data_sync();
    PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferDataSync,
                 node->fullname_with_scope(), true);
    return depended_value;
  }

  MS_LOG(EXCEPTION) << "There is no valid data for the " << i << "'th input of " << node->DebugString() << ", "
                    << node->fullname_with_scope();
}

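// Build an AbstractScalar from the first element of the depended tensor; only int32/int64/float32/float64/bool
// are supported.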
abstract::AbstractBasePtr MakeNewAbstractByScalar(const tensor::TensorPtr &depended_value) {
  abstract::AbstractBasePtr new_abs;
  MS_EXCEPTION_IF_NULL(depended_value);
  MS_EXCEPTION_IF_NULL(depended_value->Dtype());
  auto type = depended_value->Dtype()->type_id();
  if (type == kNumberTypeInt32) {
    auto tensor_data = reinterpret_cast<int32_t *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeInt64) {
    auto tensor_data = reinterpret_cast<int64_t *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeFloat32) {
    auto tensor_data = reinterpret_cast<float *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeFloat64) {
    auto tensor_data = reinterpret_cast<double *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeBool) {
    auto tensor_data = reinterpret_cast<bool *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  return new_abs;
}

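// Interpret the raw tensor buffer as `size` values of type T and wrap each one in an AbstractScalar.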
template <typename T>
abstract::AbstractBasePtrList MakeElemsByTensorValue(void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  T *tensor_data = static_cast<T *>(data);
  AbstractBasePtrList elems;
  for (size_t i = 0; i < size; i++) {
    auto scalar = std::make_shared<abstract::AbstractScalar>(tensor_data[i]);
    (void)elems.emplace_back(scalar);
  }
  return elems;
}

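// Build an AbstractTuple or AbstractList whose elements mirror the depended tensor data and attach the tensor
// as its value.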
abstract::AbstractBasePtr MakeNewAbstractBySequence(const tensor::TensorPtr &depended_value,
                                                    const abstract::AbstractBasePtr &input_abs) {
  abstract::AbstractBasePtr new_abs;
  MS_EXCEPTION_IF_NULL(depended_value);
  MS_EXCEPTION_IF_NULL(depended_value->Dtype());
  MS_EXCEPTION_IF_NULL(input_abs);
  auto type = depended_value->Dtype()->type_id();
  AbstractBasePtrList elems;
  switch (type) {
    case kNumberTypeInt32: {
      elems = MakeElemsByTensorValue<int32_t>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeInt64: {
      elems = MakeElemsByTensorValue<int64_t>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeFloat32: {
      elems = MakeElemsByTensorValue<float>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeFloat64: {
      elems = MakeElemsByTensorValue<double>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeBool: {
      elems = MakeElemsByTensorValue<bool>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    default: {
      MS_LOG(EXCEPTION) << "Unsupported type: " << type;
    }
  }
  if (input_abs->isa<abstract::AbstractTuple>()) {
    new_abs = std::make_shared<abstract::AbstractTuple>(elems);
  } else if (input_abs->isa<abstract::AbstractList>()) {
    new_abs = std::make_shared<abstract::AbstractList>(elems);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported abstract type:" << input_abs->ToString();
  }
  MS_EXCEPTION_IF_NULL(new_abs);
  new_abs->set_value(depended_value);
  return new_abs;
}

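// Clone the input's abstract, attach the depended value to it (handling tensor, scalar and sequence inputs),
// and propagate any PyExecute user data.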
abstract::AbstractBasePtr MakeNewAbstract(const AnfNodePtr &input, const tensor::TensorPtr &depended_value,
                                          const size_t &input_index) {
  MS_EXCEPTION_IF_NULL(input);
  auto abs = input->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  abstract::AbstractBasePtr new_abs;
  if (abs->isa<abstract::AbstractTensor>()) {
    new_abs = abs->Clone();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_value(depended_value);
  } else if (abs->isa<abstract::AbstractScalar>()) {
    new_abs = MakeNewAbstractByScalar(depended_value);
  } else if (AnfAlgo::IsRealSquenceOutput(input)) {
    new_abs = MakeNewAbstractBySequence(depended_value, abs);
  } else if (abs->isa<abstract::AbstractSequence>()) {
    auto abstract_seq = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abstract_seq);
    MS_EXCEPTION_IF_CHECK_FAIL((input_index < abstract_seq->elements().size()), "Index is out of range.");
    new_abs = abstract_seq->elements()[input_index]->Clone();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_value(depended_value);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported abstract type:" << abs->ToString();
  }
  // Set user data for PyExecute infer.
  if (input->has_user_data<kernel::PyExecuteOutputUserData>()) {
    const auto &output_data = input->user_data<kernel::PyExecuteOutputUserData>();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_user_data<kernel::PyExecuteOutputUserData>(output_data);
  }
  auto depend_addr = depended_value->device_address();
  if (depend_addr != nullptr) {
    MS_LOG(DEBUG) << "Input node: " << input->DebugString() << ", use user_data instead of device address";
    auto user_data = depend_addr->user_data();
    if (user_data != nullptr) {
      new_abs->set_user_data<kernel::PyExecuteOutputUserData>(
        user_data->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key));
    }
  }
  return new_abs;
}

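// Run C++ shape inference for the primitive; skipped for PyExecute nodes and for nodes whose inputs carry
// PyExecute data.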
void InferShapeForPrimitive(const CNodePtr &cnode, const PrimitivePtr &primitive,
                            const AbstractBasePtrList &args_spec_list, bool has_py_execute_data) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!has_py_execute_data && !IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
    // PyNative mode relies on the original abstract of the cnode, so the abstract cannot be modified in place;
    // clone from the old abstract instead.
    opt::CppInferShape(primitive, args_spec_list, cnode);
  }
}

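// Runtime shape inference for a kernel: collect the abstracts of all real inputs, materialize value-depend
// inputs into host tensors (recorded in depend_tensor_map), and call the primitive's infer function.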
void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *depend_tensor_map, void *args) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(depend_tensor_map);
  MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope();
  std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);

  depend_tensor_map->clear();
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs.";
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  AbstractBasePtrList args_spec_list;
  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  bool has_py_execute_data = false;
  kernel::PyExecuteOutputUserDataPtr list_user_data = nullptr;
  std::vector<size_t> list_start_index;
  for (size_t i = 0; i < input_size; i++) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
    auto real_input = input_node_with_index.first;
    auto real_input_index = input_node_with_index.second;

    MS_EXCEPTION_IF_NULL(real_input);
    if (skip_nop_node) {
      InferShapeForNopNode(real_input);
    }

    if (depend_list.find(i) != depend_list.end()) {
      auto depended_value = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args);
      auto ret2 = depend_tensor_map->try_emplace(i, depended_value);
      if (!ret2.second) {
        MS_LOG(EXCEPTION) << "Insert map failed.";
      }

      auto updated_abs = MakeNewAbstract(real_input, depended_value, real_input_index);
      MS_EXCEPTION_IF_NULL(updated_abs);
      MS_EXCEPTION_IF_NULL(real_input);
      MS_EXCEPTION_IF_NULL(real_input->abstract());
      if (updated_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
        has_py_execute_data = true;
        if (IsPrimitiveCNode(real_input, prim::kPrimPyExecute) &&
            real_input->abstract()->isa<abstract::AbstractSequence>()) {
          auto updated_abs_user_data = updated_abs->user_data<kernel::PyExecuteOutputUserData>();
          if (list_user_data == nullptr || list_user_data != updated_abs_user_data) {
            list_start_index.push_back(i);
            list_user_data = updated_abs_user_data;
          }
        }
      }
      (void)args_spec_list.emplace_back(updated_abs);
    } else {
      auto abs = real_input->abstract();
      MS_EXCEPTION_IF_NULL(abs);
      MS_LOG(DEBUG) << "Real input node:" << real_input->DebugString() << " abs:" << abs->ToString()
                    << " index:" << real_input_index;
      if (abs->isa<abstract::AbstractSequence>() && !AnfAlgo::IsRealSquenceOutput(real_input)) {
        auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
        MS_EXCEPTION_IF_NULL(abs_seq);
        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_seq->elements().size()), "Index is out of range.");
        auto abs_index = abs_seq->elements()[real_input_index];
        (void)args_spec_list.emplace_back(abs_index);
      } else {
        (void)args_spec_list.emplace_back(abs);
      }
    }
  }
  MS_EXCEPTION_IF_NULL(inputs[0]);
  if (auto primitive = GetValueNode<PrimitivePtr>(inputs[0])) {
    MS_EXCEPTION_IF_NULL(primitive);
    (void)primitive->AddAttr(kAttrListStartIndex, MakeValue(list_start_index));
    InferShapeForPrimitive(cnode, primitive, args_spec_list, has_py_execute_data);
  } else {
    MS_LOG(EXCEPTION) << "The first input of the cnode should be either a primitive or a function graph, but got: "
                      << inputs[0]->fullname_with_scope();
  }
  MS_LOG(DEBUG) << "InferShape end, node:" << cnode->fullname_with_scope();
}

inline bool IsCpuKernelMod(kernel::KernelModType kernel_mod_type) {
  return kernel_mod_type == kernel::KernelModType::NativeCpuKernelMod;
}
}  // namespace

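// Backend InferShape entry: prefer a registered InferShapeFunctor, then the op's FuncImpl, and finally the
// legacy backend infer-impl registry.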
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->HasAttr(kAttrInferShapeFunctor)) {
    auto functor = primitive->GetAttr(kAttrInferShapeFunctor)->cast<InferShapeFunctorPtr>();
    MS_EXCEPTION_IF_NULL(functor);
    return functor->InferShape(input_args);
  }
  const auto &op_name = primitive->name();
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferInner,
                                     op_name, true);
  auto shape_optional = abstract::InferShapeByFuncImpl(primitive, input_args, false);
  if (shape_optional.has_value()) {
    return shape_optional.value();
  }

  // The old register map for InferShape will be deleted in the future.
  auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
  if (infer_impl.has_value()) {
    auto infer = infer_impl.value();
    if (infer.IsImplInferShapeAndType()) {
      return infer.InferShape(primitive, input_args);
    }
  }
  MS_LOG(EXCEPTION) << "The InferShape function of [" << op_name << "] is not defined.";
}

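// Write the inferred shape back to the output kernel tensors, unpacking a SequenceShape when the outputs have
// been expanded into multiple tensors.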
void UpdateKernelTensorShape(const BaseShapePtr &base_shape,
                             const std::vector<kernel::KernelTensor *> &output_kernel_tensors) {
  MS_EXCEPTION_IF_NULL(base_shape);
  size_t output_num = output_kernel_tensors.size();
  if (output_num > 1) {
    auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(sequence_shape);
    const auto &shapes = sequence_shape->shape();
    if (shapes.size() != output_num) {
      MS_LOG(EXCEPTION) << "Invalid SequenceShape, expected elements number: " << output_num
                        << ", but got: " << shapes.size();
    }
    for (size_t i = 0; i < output_num; i++) {
      const auto &kernel_tensor = output_kernel_tensors[i];
      MS_EXCEPTION_IF_NULL(kernel_tensor);
      kernel_tensor->SetShape(shapes[i]);
    }
  } else if (output_num == 1) {
    const auto &kernel_tensor = output_kernel_tensors[0];
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    if ((kernel_tensor->type_id() != kObjectTypeTuple && kernel_tensor->type_id() != kObjectTypeList) &&
        sequence_shape != nullptr) {
      // For an operator prototype whose output is a Tuple, the backend operator is expanded into Tensors. In the
      // single-output case the InferShape result is still a TupleShape, so the backend needs to unpack it into a
      // TensorShape; for example, a Split operator that produces only one Tensor.
      const auto &shapes = sequence_shape->shape();
      if (shapes.size() != 1) {
        MS_LOG(EXCEPTION) << "Invalid SequenceShape, expected elements number: " << 1 << ", but got: " << shapes.size();
      }

      kernel_tensor->SetShape(shapes[0]);
    } else {
      kernel_tensor->SetShape(base_shape);
    }
  }
}

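// Infer both shape and type through the legacy backend infer-impl registry; throws if no impl is registered.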
abstract::AbstractBasePtr InferShapeAndType(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &op_name = primitive->name();
  auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
  if (infer_impl.has_value()) {
    auto infer = infer_impl.value();
    if (infer.IsImplInferShapeAndType()) {
      return infer.InferShapeAndType(nullptr, primitive, input_args);
    }
  }
  MS_LOG(EXCEPTION) << "The InferShapeAndType function of [" << op_name << "] is not defined.";
}

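// Write the inferred type back to the single output kernel tensor.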
void UpdateKernelTensorType(const TypePtr &type, const std::vector<kernel::KernelTensor *> &output_kernel_tensors) {
  MS_EXCEPTION_IF_NULL(type);
  if (output_kernel_tensors.size() != 1) {
    MS_LOG(EXCEPTION) << "Invalid output size:" << output_kernel_tensors.size();
  }

  const auto &kernel_tensor = output_kernel_tensors[0];
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  kernel_tensor->SetType(type);
}

bool IsRealCNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n)) {
    CNodePtr cnode = utils::cast<CNodePtr>(n);
    return AnfUtils::IsRealKernel(cnode);
  }
  return false;
}

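// Create the custom infer actor node, whose callback runs InferOp for the kernel at execution time.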
AnfNodePtr GenInferNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto infer_node = AnfUtils::NewInferActorNode([cnode](void *args) { InferOp(cnode, args); }, cnode);
  MS_EXCEPTION_IF_NULL(infer_node);
  infer_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return infer_node;
}

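// Create the custom init actor node, whose callback resizes the kernel mod with the latest input/output
// kernel tensors.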
AnfNodePtr GenInitNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AnfUtils::CustomActorCallback actor_func = [kernel_mod, cnode](void *) {
    auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(cnode);
    auto outputs = AnfAlgo::GetOrCreateAllOutputKernelTensors(cnode);
    if (kernel_mod->Resize(inputs, outputs) == static_cast<int>(kernel::KRET_RESIZE_FAILED)) {
      MS_LOG(EXCEPTION) << "Node " << cnode->fullname_with_scope() << " Resize failed.";
    }
  };

  auto init_node = AnfUtils::NewInitActorNode(actor_func, cnode);
  MS_EXCEPTION_IF_NULL(init_node);
  init_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return init_node;
}

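// Infer actor callback: infer the shape of the cnode, rebuild its kernel args from the new abstracts, merge in
// the depend tensor map, and attach the updated args to the cnode.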
void InferOp(const CNodePtr &cnode, void *args) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  kernel::KernelArgs kernel_args;
  MS_LOG(DEBUG) << "Infer shape for node:" << cnode->fullname_with_scope();
  InferShape(cnode, &kernel_args.depend_tensor_map, args);
  auto kernel_mod_type = kernel_mod->GetKernelModType();
  auto update = kernel::AbstractArgsFromCNode(cnode);
  update.depend_tensor_map = std::move(kernel_args.depend_tensor_map);
  kernel::SetInputsByDependMap(update.depend_tensor_map, &update.inputs, IsCpuKernelMod(kernel_mod_type));
  kernel::SetArgsToCNode(cnode, update);
}

CustomActorNodeManager &CustomActorNodeManager::Instance() {
  static CustomActorNodeManager instance{};
  return instance;
}
}  // namespace opt::dynamic_shape
}  // namespace mindspore