/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"

#include <memory>
#include <algorithm>
#include <stack>
#include <set>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include "mindspore/core/ops/framework_ops.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "utils/anf_utils.h"
#include "kernel/framework_utils.h"
#include "ops/op_def.h"
#include "utils/ms_context.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "include/common/profiler.h"
#include "ir/anf.h"
#include "ir/functor.h"
#include "backend/operator/ops_backend_infer_function.h"

namespace mindspore {
namespace opt::dynamic_shape {
namespace {
constexpr int64_t kInvalidShape = -2;

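// Propagates shapes through a chain of skipped nop nodes: walks backwards from input_node through consecutive nop
// inputs, then re-infers each nop node from the producer side forward so its output shape matches the real input.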
void InferShapeForNopNode(const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  if (!common::AnfAlgo::IsNopNode(input_node)) {
    MS_LOG(INFO) << "Input node is not a nop node, no need to infer.";
    return;
  }
  if (!common::AnfAlgo::IsNeedSkipNopOpExecution(input_node)) {
    MS_LOG(INFO) << "The nop node needs execution, skip InferShapeForNopNode.";
    return;
  }
  MS_LOG(INFO) << "Infer shape for nop node.";
  std::stack<AnfNodePtr> nop_road;
  nop_road.push(input_node);

  auto in_node = input_node;
  while (true) {
    auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(in_node, 0);
    in_node = input_node_with_idx.first;
    MS_EXCEPTION_IF_NULL(in_node);
    if (common::AnfAlgo::IsNopNode(in_node)) {
      nop_road.push(in_node);
    } else {
      break;
    }
  }

  while (!nop_road.empty()) {
    auto nop_node = nop_road.top();
    MS_EXCEPTION_IF_NULL(nop_node);
    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
    nop_road.pop();
  }
}

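// Returns the common element type of a sequence abstract whose elements are all scalars or tensors; raises an
// exception if the elements have mixed types.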
TypeId GetSequenceType(const abstract::AbstractSequencePtr &seq_abs) {
  MS_EXCEPTION_IF_NULL(seq_abs);
  auto elems = seq_abs->elements();
  MS_EXCEPTION_IF_CHECK_FAIL(elems.size() >= 1, "Element size is less than 1.");
  MS_EXCEPTION_IF_NULL(elems[0]);
  if (!elems[0]->isa<abstract::AbstractScalar>() && !elems[0]->isa<abstract::AbstractTensor>()) {
    MS_LOG(EXCEPTION) << "The 0'th element of sequence must be a scalar or tensor, but got:" << seq_abs->ToString();
  }

  auto fixed_type = (elems[0]->isa<abstract::AbstractScalar>()
                       ? elems[0]->BuildType()->type_id()
                       : elems[0]->cast<abstract::AbstractTensorPtr>()->element()->BuildType()->type_id());
  for (size_t i = 1; i < elems.size(); i++) {
    MS_EXCEPTION_IF_NULL(elems[i]);
    if (!elems[i]->isa<abstract::AbstractScalar>() && !elems[i]->isa<abstract::AbstractTensor>()) {
      MS_LOG(EXCEPTION) << "The " << i << "'th element of sequence must be a scalar or tensor, but got:"
                        << elems[i]->ToString();
    }
    MS_EXCEPTION_IF_NULL(elems[i]->BuildType());
    auto follow_type = (elems[i]->isa<abstract::AbstractScalar>()
                          ? elems[i]->BuildType()->type_id()
                          : elems[i]->cast<abstract::AbstractTensorPtr>()->element()->BuildType()->type_id());
    if (fixed_type != follow_type) {
      MS_LOG(EXCEPTION) << "Different type found between 0'th element[Type: " << fixed_type << "] and " << i
                        << "'th element[Type: " << follow_type << "]";
    }
  }
  return fixed_type;
}

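// Creates an empty host tensor whose dtype and shape are derived from the abstract (and, when available, the kernel
// info) of the given node output, so the caller can fill in its data afterwards.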
tensor::TensorPtr CreateTensorFromIndexedNode(const std::pair<AnfNodePtr, size_t> &input_node_with_index) {
  auto real_input = input_node_with_index.first;
  MS_EXCEPTION_IF_NULL(real_input);
  auto real_input_index = input_node_with_index.second;
  auto abs = real_input->abstract();
  MS_EXCEPTION_IF_NULL(abs);

  ShapeVector shape;
  TypeId type;
  if (abs->isa<abstract::AbstractScalar>()) {
    shape = {1};
    MS_EXCEPTION_IF_NULL(abs->BuildType());
    type = abs->BuildType()->type_id();
  } else if (AnfAlgo::IsRealSquenceOutput(real_input)) {
    auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(seq_abs);
    auto elem_num = seq_abs->size();
    if (elem_num == 0) {
      MS_LOG(DEBUG) << "Empty sequence for node:" << real_input->fullname_with_scope();
      return std::make_shared<tensor::Tensor>(TypeId::kNumberTypeInt64, ShapeVector({0}));
    }
    type = GetSequenceType(seq_abs);
    shape = {SizeToLong(elem_num)};
  } else if (abs->isa<abstract::AbstractTensor>() || abs->isa<abstract::AbstractSequence>()) {
    shape = trans::GetRuntimePaddingShape(real_input, real_input_index);
    if (real_input->isa<ValueNode>()) {
      // The type of a ValueNode in KernelInfo is kTypeUnknown, so use the inferred type instead.
      type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
    } else {
      type = AnfAlgo::GetOutputDeviceDataType(real_input, real_input_index);
      if (type == TypeId::kTypeUnknown) {
        type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
      }
    }
  } else {
    MS_LOG(EXCEPTION) << "For node:" << real_input->fullname_with_scope() << ", abstract(" << abs->ToString()
                      << ") is invalid.";
  }

  MS_LOG(DEBUG) << "Create tensor by node:" << input_node_with_index.first->DebugString()
                << " index:" << input_node_with_index.second << " type:" << type << " shape:" << shape
                << " abstract:" << abs->ToString();
  return std::make_shared<tensor::Tensor>(type, shape);
}

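// Allocates the host tensor used to stage a depend value. For PyExecute nodes the dtype and shape are taken from the
// input device address list (args); otherwise they are derived from the input node itself.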
tensor::TensorPtr CreateTensorMem(const std::pair<AnfNodePtr, size_t> &input_node_with_index, const AnfNodePtr &node,
                                  size_t i, void *args) {
  if (node != nullptr && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPyExecute)) {
    MS_EXCEPTION_IF_NULL(args);
    auto input_list = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
    MS_EXCEPTION_IF_NULL(input_list);
    if (i >= input_list->size() || input_list->at(i) == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to get device address by input num:" << i << " for node:" << node->DebugString();
    }
    const auto &device_address = input_list->at(i);
    MS_EXCEPTION_IF_NULL(device_address->kernel_tensor());
    MS_LOG(DEBUG) << "input node:" << input_node_with_index.first->DebugString()
                  << " abstract:" << input_node_with_index.first->abstract()->ToString()
                  << " device address:" << device_address << " type id:" << device_address->kernel_tensor()->dtype_id()
                  << " shape vector:" << device_address->kernel_tensor()->GetShapeVector();
    auto type_id = device_address->kernel_tensor()->dtype_id();
    if (device_address->kernel_tensor()->GetType() != nullptr &&
        ((device_address->kernel_tensor()->GetType()->isa<Tuple>() &&
          device_address->kernel_tensor()->GetType()->cast<TuplePtr>()->size() == 0) ||
         (device_address->kernel_tensor()->GetType()->isa<List>() &&
          device_address->kernel_tensor()->GetType()->cast<ListPtr>()->size() == 0))) {
      type_id = TypeId::kNumberTypeInt64;
    }
    return std::make_shared<tensor::Tensor>(type_id, device_address->kernel_tensor()->GetShapeVector());
  }

  return CreateTensorFromIndexedNode(input_node_with_index);
}

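// Builds the host tensor holding the value of a value-depend input. The value is taken, in order of preference,
// from a constant ValueNode (PyExecute only), from the launch-time device address passed in args, or from the output
// device address of the input node as a fallback; an exception is raised if no valid data source exists.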
tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
                                       const std::pair<AnfNodePtr, size_t> &input_node_with_index, bool skip_nop_node,
                                       void *args) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input_node_with_index.first);
  if (IsPrimitiveCNode(node, prim::kPrimPyExecute) && input_node_with_index.first->isa<ValueNode>()) {
    const auto &value_node = input_node_with_index.first->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::Tensor>()) {
      return value->cast<tensor::TensorPtr>();
    } else if (value->isa<Scalar>()) {
      return ScalarToTensor(value->cast<ScalarPtr>());
    }
  }
  auto depended_value = CreateTensorMem(input_node_with_index, node, i, args);
  MS_EXCEPTION_IF_NULL(depended_value);
  // First, use the data from args.
  if (args != nullptr) {
    auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
    MS_EXCEPTION_IF_NULL(input_device_address);
    if (i < input_device_address->size() && input_device_address->at(i) != nullptr) {
      uint64_t start_time = 0;
      PROFILER_START(start_time);
      auto addr = reinterpret_cast<device::DeviceAddress *>(input_device_address->at(i));
      MS_EXCEPTION_IF_NULL(addr);
      auto node_idx = addr->node_index();
      auto user_data = addr->user_data();
      if (user_data != nullptr && user_data->has(kernel::PyExecuteOutputUserData::key)) {
        auto addr_node = node_idx.first.lock();
        MS_EXCEPTION_IF_NULL(addr_node);
        auto out_addr = AnfAlgo::GetMutableOutputAddr(addr_node, node_idx.second, skip_nop_node);
        depended_value->set_device_address(out_addr, false);
        return depended_value;
      }
      MS_LOG(DEBUG) << "Get depend value tensor for node:" << node->DebugString() << " input index:" << i
                    << " input node:" << input_node_with_index.first->DebugString() << " index:"
                    << input_node_with_index.second << " node addr:" << input_node_with_index.first
                    << " device_address:" << input_device_address->at(i)
                    << " type id:" << input_device_address->at(i)->type_id();
      depended_value->data_sync_directly(input_device_address->at(i));
      PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferDataSync,
                   node->fullname_with_scope(), true);
      return depended_value;
    }
    MS_LOG(WARNING) << "There is no valid data for " << i << " input of " << node->DebugString() << ", "
                    << node->fullname_with_scope();
  }

  // Second, fall back to the output device address of the input node.
  auto output_addr =
    AnfAlgo::GetMutableOutputAddr(input_node_with_index.first, input_node_with_index.second, skip_nop_node);
  MS_EXCEPTION_IF_NULL(output_addr);
  if (output_addr != nullptr && output_addr->IsPtrValid()) {
    // The second parameter must be false, otherwise the device address cannot be released and allocated, and the
    // address size will be wrong in the dynamic shape scenario.
    depended_value->set_device_address(output_addr, false);
    uint64_t start_time = 0;
    PROFILER_START(start_time);
    // PyExecute uses the data in user_data instead of the address, so there is no need to sync data from device.
    if (IsPrimitiveCNode(input_node_with_index.first, prim::kPrimPyExecute)) {
      MS_LOG(DEBUG) << "The input node is " << input_node_with_index.first->ToString()
                    << ", use user data instead of address.";
      return depended_value;
    }
    MS_LOG(DEBUG) << "Get depend value tensor for node:" << node->DebugString() << " input index:" << i
                  << " input node:" << input_node_with_index.first->DebugString() << " index:"
                  << input_node_with_index.second << " node addr:" << input_node_with_index.first
                  << " sync for device tensor:" << output_addr;
    depended_value->data_sync();
    PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferDataSync,
                 node->fullname_with_scope(), true);
    return depended_value;
  }

  MS_LOG(EXCEPTION) << "There is no valid data for " << i << " input of " << node->DebugString() << ", "
                    << node->fullname_with_scope();
}

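// Wraps the single value held by the depended tensor into an AbstractScalar of the matching C++ type.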
abstract::AbstractBasePtr MakeNewAbstractByScalar(const tensor::TensorPtr &depended_value) {
  abstract::AbstractBasePtr new_abs;
  MS_EXCEPTION_IF_NULL(depended_value);
  MS_EXCEPTION_IF_NULL(depended_value->Dtype());
  auto type = depended_value->Dtype()->type_id();
  if (type == kNumberTypeInt32) {
    auto tensor_data = reinterpret_cast<int32_t *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeInt64) {
    auto tensor_data = reinterpret_cast<int64_t *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeFloat32) {
    auto tensor_data = reinterpret_cast<float *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeFloat64) {
    auto tensor_data = reinterpret_cast<double *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else if (type == kNumberTypeBool) {
    auto tensor_data = reinterpret_cast<bool *>(depended_value->data_c());
    MS_EXCEPTION_IF_NULL(tensor_data);
    new_abs = std::make_shared<abstract::AbstractScalar>(*tensor_data);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  return new_abs;
}

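// Converts a flat tensor buffer of `size` elements of type T into a list of AbstractScalar elements.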
template <typename T>
abstract::AbstractBasePtrList MakeElemsByTensorValue(void *data, size_t size) {
  MS_EXCEPTION_IF_NULL(data);
  T *tensor_data = static_cast<T *>(data);
  AbstractBasePtrList elems;
  for (size_t i = 0; i < size; i++) {
    auto scalar = std::make_shared<abstract::AbstractScalar>(tensor_data[i]);
    (void)elems.emplace_back(scalar);
  }
  return elems;
}

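// Rebuilds a tuple/list abstract from the depended tensor's data, keeping the container kind of input_abs and
// attaching the tensor as the abstract's value.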
abstract::AbstractBasePtr MakeNewAbstractBySequence(const tensor::TensorPtr &depended_value,
                                                    const abstract::AbstractBasePtr &input_abs) {
  abstract::AbstractBasePtr new_abs;
  MS_EXCEPTION_IF_NULL(depended_value);
  MS_EXCEPTION_IF_NULL(depended_value->Dtype());
  MS_EXCEPTION_IF_NULL(input_abs);
  auto type = depended_value->Dtype()->type_id();
  AbstractBasePtrList elems;
  switch (type) {
    case kNumberTypeInt32: {
      elems = MakeElemsByTensorValue<int32_t>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeInt64: {
      elems = MakeElemsByTensorValue<int64_t>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeFloat32: {
      elems = MakeElemsByTensorValue<float>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeFloat64: {
      elems = MakeElemsByTensorValue<double>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    case kNumberTypeBool: {
      elems = MakeElemsByTensorValue<bool>(depended_value->data_c(), depended_value->DataSize());
      break;
    }
    default: {
      MS_LOG(EXCEPTION) << "Unsupported type: " << type;
    }
  }
  if (input_abs->isa<abstract::AbstractTuple>()) {
    new_abs = std::make_shared<abstract::AbstractTuple>(elems);
  } else if (input_abs->isa<abstract::AbstractList>()) {
    new_abs = std::make_shared<abstract::AbstractList>(elems);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported abstract type:" << input_abs->ToString();
  }
  MS_EXCEPTION_IF_NULL(new_abs);
  new_abs->set_value(depended_value);
  return new_abs;
}

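// Clones the input's abstract and attaches the depended value to it so that the following C++ infer sees a constant
// input. PyExecute user data is propagated onto the new abstract when present.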
abstract::AbstractBasePtr MakeNewAbstract(const AnfNodePtr &input, const tensor::TensorPtr &depended_value,
                                          const size_t &input_index) {
  MS_EXCEPTION_IF_NULL(input);
  auto abs = input->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  abstract::AbstractBasePtr new_abs;
  if (abs->isa<abstract::AbstractTensor>()) {
    new_abs = abs->Clone();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_value(depended_value);
  } else if (abs->isa<abstract::AbstractScalar>()) {
    new_abs = MakeNewAbstractByScalar(depended_value);
  } else if (AnfAlgo::IsRealSquenceOutput(input)) {
    new_abs = MakeNewAbstractBySequence(depended_value, abs);
  } else if (abs->isa<abstract::AbstractSequence>()) {
    auto abstract_seq = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abstract_seq);
    MS_EXCEPTION_IF_CHECK_FAIL((input_index < abstract_seq->elements().size()), "Index is out of range.");
    new_abs = abstract_seq->elements()[input_index]->Clone();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_value(depended_value);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported abstract type:" << abs->ToString();
  }
  // Set user data for PyExecute infer.
  if (input->has_user_data<kernel::PyExecuteOutputUserData>()) {
    const auto &output_data = input->user_data<kernel::PyExecuteOutputUserData>();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->set_user_data<kernel::PyExecuteOutputUserData>(output_data);
  }
  auto depend_addr = depended_value->device_address();
  if (depend_addr != nullptr) {
    MS_LOG(DEBUG) << "Input node: " << input->DebugString() << ", use user_data instead of device address";
    auto user_data = depend_addr->user_data();
    if (user_data != nullptr) {
      new_abs->set_user_data<kernel::PyExecuteOutputUserData>(
        user_data->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key));
    }
  }
  return new_abs;
}

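// Runs the backend C++ shape infer for the primitive, skipping it when the cnode is a PyExecute node or when
// PyExecute user data was found among the inputs.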
void InferShapeForPrimitive(const CNodePtr &cnode, const PrimitivePtr &primitive,
                            const AbstractBasePtrList &args_spec_list, bool has_py_execute_data) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!has_py_execute_data && !IsPrimitiveCNode(cnode, prim::kPrimPyExecute)) {
    // PyNative mode relies on the original abstract of the cnode, so the abstract cannot be modified in place;
    // clone from the old abstract instead.
    opt::CppInferShape(primitive, args_spec_list, cnode);
  }
}

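// Re-infers the abstract of a dynamic-shape cnode at launch time: collects the abstracts of all real inputs,
// materializes the values of value-depend inputs into depend_tensor_map, and then calls the primitive infer.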
void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *depend_tensor_map, void *args) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(depend_tensor_map);
  MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope();
  std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);

  depend_tensor_map->clear();
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs.";
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  AbstractBasePtrList args_spec_list;
  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
  bool has_py_execute_data = false;
  kernel::PyExecuteOutputUserDataPtr list_user_data = nullptr;
  std::vector<size_t> list_start_index;
  for (size_t i = 0; i < input_size; i++) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
    auto real_input = input_node_with_index.first;
    auto real_input_index = input_node_with_index.second;

    MS_EXCEPTION_IF_NULL(real_input);
    if (skip_nop_node) {
      InferShapeForNopNode(real_input);
    }

    if (depend_list.find(i) != depend_list.end()) {
      auto depended_value = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args);
      auto ret2 = depend_tensor_map->try_emplace(i, depended_value);
      if (!ret2.second) {
        MS_LOG(EXCEPTION) << "Insert map failed.";
      }

      auto updated_abs = MakeNewAbstract(real_input, depended_value, real_input_index);
      MS_EXCEPTION_IF_NULL(updated_abs);
      MS_EXCEPTION_IF_NULL(real_input);
      MS_EXCEPTION_IF_NULL(real_input->abstract());
      if (updated_abs->has_user_data<kernel::PyExecuteOutputUserData>()) {
        has_py_execute_data = true;
        if (IsPrimitiveCNode(real_input, prim::kPrimPyExecute) &&
            real_input->abstract()->isa<abstract::AbstractSequence>()) {
          auto updated_abs_user_data = updated_abs->user_data<kernel::PyExecuteOutputUserData>();
          if (list_user_data == nullptr || list_user_data != updated_abs_user_data) {
            list_start_index.push_back(i);
            list_user_data = updated_abs_user_data;
          }
        }
      }
      (void)args_spec_list.emplace_back(updated_abs);
    } else {
      auto abs = real_input->abstract();
      MS_EXCEPTION_IF_NULL(abs);
      MS_LOG(DEBUG) << "Real input node:" << real_input->DebugString() << " abs:" << abs->ToString()
                    << " index:" << real_input_index;
      if (abs->isa<abstract::AbstractSequence>() && !AnfAlgo::IsRealSquenceOutput(real_input)) {
        auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
        MS_EXCEPTION_IF_NULL(abs_seq);
        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abs_seq->elements().size()), "Index is out of range.");
        auto abs_index = abs_seq->elements()[real_input_index];
        (void)args_spec_list.emplace_back(abs_index);
      } else {
        (void)args_spec_list.emplace_back(abs);
      }
    }
  }
  MS_EXCEPTION_IF_NULL(inputs[0]);
  if (auto primitive = GetValueNode<PrimitivePtr>(inputs[0])) {
    MS_EXCEPTION_IF_NULL(primitive);
    (void)primitive->AddAttr(kAttrListStartIndex, MakeValue(list_start_index));
    InferShapeForPrimitive(cnode, primitive, args_spec_list, has_py_execute_data);
  } else {
    MS_LOG(EXCEPTION) << "The first input of the cnode should be either a primitive or a function graph, but got: "
                      << inputs[0]->fullname_with_scope();
  }
  MS_LOG(DEBUG) << "InferShape end, node:" << cnode->fullname_with_scope();
}

inline bool IsCpuKernelMod(kernel::KernelModType kernel_mod_type) {
  return kernel_mod_type == kernel::KernelModType::NativeCpuKernelMod;
}
}  // namespace

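// Infers the output shape of a primitive, trying the attached InferShapeFunctor first, then the new func_impl based
// infer, and finally the legacy backend infer map.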
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->HasAttr(kAttrInferShapeFunctor)) {
    auto functor = primitive->GetAttr(kAttrInferShapeFunctor)->cast<InferShapeFunctorPtr>();
    MS_EXCEPTION_IF_NULL(functor);
    return functor->InferShape(input_args);
  }
  const auto &op_name = primitive->name();
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kKernelInferInner,
                                     op_name, true);
  auto shape_optional = abstract::InferShapeByFuncImpl(primitive, input_args, false);
  if (shape_optional.has_value()) {
    return shape_optional.value();
  }

  // The old register map for InferShape will be deleted in the future.
  auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
  if (infer_impl.has_value()) {
    auto infer = infer_impl.value();
    if (infer.IsImplInferShapeAndType()) {
      return infer.InferShape(primitive, input_args);
    }
  }
  MS_LOG(EXCEPTION) << "The InferShape function of [" << op_name << "] is not defined.";
}

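// Writes the inferred shape back into the output kernel tensors; a SequenceShape is split across multiple outputs,
// and a single-element SequenceShape is unwrapped for a single tensor output.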
void UpdateKernelTensorShape(const BaseShapePtr &base_shape,
                             const std::vector<kernel::KernelTensor *> &output_kernel_tensors) {
  MS_EXCEPTION_IF_NULL(base_shape);
  size_t output_num = output_kernel_tensors.size();
  if (output_num > 1) {
    auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    MS_EXCEPTION_IF_NULL(sequence_shape);
    const auto &shapes = sequence_shape->shape();
    if (shapes.size() != output_num) {
      MS_LOG(EXCEPTION) << "Invalid SequenceShape, expected elements number: " << output_num
                        << ", but got: " << shapes.size();
    }
    for (size_t i = 0; i < output_num; i++) {
      const auto &kernel_tensor = output_kernel_tensors[i];
      MS_EXCEPTION_IF_NULL(kernel_tensor);
      kernel_tensor->SetShape(shapes[i]);
    }
  } else if (output_num == 1) {
    const auto &kernel_tensor = output_kernel_tensors[0];
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    auto sequence_shape = base_shape->cast<abstract::SequenceShapePtr>();
    if ((kernel_tensor->type_id() != kObjectTypeTuple && kernel_tensor->type_id() != kObjectTypeList) &&
        sequence_shape != nullptr) {
      // For an operator prototype whose output is a Tuple, the backend operator is expanded into Tensors. In the
      // single-output scenario the InferShape result is still a TupleShape, so the backend needs to unwrap it into a
      // TensorShape, e.g. when the Split operator produces only one Tensor.
      const auto &shapes = sequence_shape->shape();
      if (shapes.size() != 1) {
        MS_LOG(EXCEPTION) << "Invalid SequenceShape, expected elements number: " << 1 << ", but got: " << shapes.size();
      }

      kernel_tensor->SetShape(shapes[0]);
    } else {
      kernel_tensor->SetShape(base_shape);
    }
  }
}

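// Infers both shape and type through the legacy backend infer map; raises an exception if the primitive has no
// registered implementation.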
abstract::AbstractBasePtr InferShapeAndType(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &op_name = primitive->name();
  auto infer_impl = abstract::GetBackendPrimitiveInferImpl(primitive);
  if (infer_impl.has_value()) {
    auto infer = infer_impl.value();
    if (infer.IsImplInferShapeAndType()) {
      return infer.InferShapeAndType(nullptr, primitive, input_args);
    }
  }
  MS_LOG(EXCEPTION) << "The InferShape function of [" << op_name << "] is not defined.";
}

void UpdateKernelTensorType(const TypePtr &type, const std::vector<kernel::KernelTensor *> &output_kernel_tensors) {
  MS_EXCEPTION_IF_NULL(type);
  if (output_kernel_tensors.size() != 1) {
    MS_LOG(EXCEPTION) << "Invalid output size:" << output_kernel_tensors.size();
  }

  const auto &kernel_tensor = output_kernel_tensors[0];
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  kernel_tensor->SetType(type);
}

bool IsRealCNode(const BaseRef &n) {
  if (utils::isa<CNodePtr>(n)) {
    CNodePtr cnode = utils::cast<CNodePtr>(n);
    return AnfUtils::IsRealKernel(cnode);
  }
  return false;
}

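// Creates the custom "infer" actor node that re-runs InferOp for the kernel at execution time.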
AnfNodePtr GenInferNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto infer_node = AnfUtils::NewInferActorNode([cnode](void *args) { InferOp(cnode, args); }, cnode);
  MS_EXCEPTION_IF_NULL(infer_node);
  infer_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return infer_node;
}

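// Creates the custom "init" actor node whose callback resizes the kernel mod from the current input/output kernel
// tensors before launch.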
AnfNodePtr GenInitNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AnfUtils::CustomActorCallback actor_func = [kernel_mod, cnode](void *) {
    auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(cnode);
    auto outputs = AnfAlgo::GetOrCreateAllOutputKernelTensors(cnode);
    if (kernel_mod->Resize(inputs, outputs) == static_cast<int>(kernel::KRET_RESIZE_FAILED)) {
      MS_LOG(EXCEPTION) << "Node " << cnode->fullname_with_scope() << " Resize failed.";
    }
  };

  auto init_node = AnfUtils::NewInitActorNode(actor_func, cnode);
  MS_EXCEPTION_IF_NULL(init_node);
  init_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return init_node;
}

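// Entry point of the infer actor: re-infers the cnode's shape, rebuilds the kernel args from the updated abstract,
// and stores them back on the cnode for the following resize/launch steps.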
void InferOp(const CNodePtr &cnode, void *args) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  kernel::KernelArgs kernel_args;
  MS_LOG(DEBUG) << "Infer shape for node:" << cnode->fullname_with_scope();
  InferShape(cnode, &kernel_args.depend_tensor_map, args);
  auto kernel_mod_type = kernel_mod->GetKernelModType();
  auto update = kernel::AbstractArgsFromCNode(cnode);
  update.depend_tensor_map = std::move(kernel_args.depend_tensor_map);
  kernel::SetInputsByDependMap(update.depend_tensor_map, &update.inputs, IsCpuKernelMod(kernel_mod_type));
  kernel::SetArgsToCNode(cnode, update);
}

CustomActorNodeManager &CustomActorNodeManager::Instance() {
  static CustomActorNodeManager instance{};
  return instance;
}
}  // namespace opt::dynamic_shape
}  // namespace mindspore