/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/ir/dynamic_shape.h"
#include <algorithm>
#include "pipeline/pynative/pynative_utils.h"

namespace mindspore {
namespace pynative {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr size_t kMaxCacheDynamicShapeCellNum = 2;

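// Compare two values for equality; tensors are compared by content via ValueEqual.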
bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
  if (v1 == v2) {
    return true;
  }
  if (v1 == nullptr || v2 == nullptr) {
    return false;
  }
  if (v1->isa<tensor::BaseTensor>() && v2->isa<tensor::BaseTensor>()) {
    return v1->cast<tensor::BaseTensorPtr>()->ValueEqual(*(v2->cast<tensor::BaseTensorPtr>()));
  }
  return *v1 == *v2;
}

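// Check whether the primitive changed between runs; kernel-graph-only attrs are
// erased from the old primitive before the comparison.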
bool IsDynamicDetectPrimChange(const PrimitivePtr &old_prim, const PrimitivePtr &new_prim) {
  if (old_prim == nullptr && new_prim == nullptr) {
    return false;
  }
  // Using a kernel graph adds the kIsFeatureMapOutput and kIsFeatureMapInputList attrs,
  // which must be removed before the check.
  if (old_prim != nullptr && old_prim->HasAttr(kIsFeatureMapOutput)) {
    old_prim->EraseAttr(kIsFeatureMapOutput);
    old_prim->EraseAttr(kIsFeatureMapInputList);
  }
  if (new_prim != nullptr && old_prim != nullptr) {
    return !common::IsEqual(old_prim, new_prim);
  }
  return true;
}

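// Recursively check whether a recorded input changed: sequence structure, parameter
// identity, grad type, producing op index, or constant value.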
bool IsNodeInfoChange(const NodeInfo &old_node_info, const NodeInfo &new_node_info) {
  size_t input_size = old_node_info.seq_node.size();
  if (input_size != new_node_info.seq_node.size()) {
    MS_LOG(DEBUG) << "Graph is dynamic, input is tuple, but old seq node info size " << input_size
                  << ", new seq node info size " << new_node_info.seq_node.size();
    return true;
  }
  for (size_t i = 0; i < input_size; ++i) {
    if (IsNodeInfoChange(old_node_info.seq_node[i], new_node_info.seq_node[i])) {
      return true;
    }
  }

  if (new_node_info.grad_type == InputType::kParameter &&
      (old_node_info.grad_type == InputType::kParameter || old_node_info.grad_type == InputType::kConstant)) {
    MS_EXCEPTION_IF_NULL(new_node_info.value);
    MS_EXCEPTION_IF_NULL(old_node_info.value);
    auto new_tensor = new_node_info.value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(new_tensor);
    auto old_tensor = old_node_info.value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(old_tensor);
    if (new_tensor->id() != old_tensor->id()) {
      MS_LOG(DEBUG) << "Graph is dynamic, new node info value: "
                    << (new_node_info.value != nullptr ? new_node_info.value->ToString() : "")
                    << ", grad type: " << new_node_info.grad_type << ", old node info value: "
                    << (old_node_info.value != nullptr ? old_node_info.value->ToString() : "")
                    << ", grad type: " << old_node_info.grad_type;
      return true;
    }
    return false;
  }

  if (new_node_info.grad_type != old_node_info.grad_type) {
    MS_LOG(DEBUG) << "Graph is dynamic, new node info grad type: " << new_node_info.grad_type
                  << ", old node info grad type: " << old_node_info.grad_type;
    return true;
  }

  if (new_node_info.grad_type == InputType::kOpOutput && new_node_info.op_index != old_node_info.op_index) {
    MS_LOG(DEBUG) << "Graph is dynamic, new node info op_index: " << new_node_info.op_index
                  << ", old node info op_index: " << old_node_info.op_index;
    return true;
  }

  if (new_node_info.grad_type == InputType::kConstant && !IsValuePtrEqual(new_node_info.value, old_node_info.value)) {
    MS_LOG(DEBUG) << "Graph is dynamic, new node info value: "
                  << (new_node_info.value != nullptr ? new_node_info.value->ToString() : "")
                  << ", grad type: " << new_node_info.grad_type << ", old node info value: "
                  << (old_node_info.value != nullptr ? old_node_info.value->ToString() : "")
                  << ", grad type: " << old_node_info.grad_type;
    return true;
  }

  return false;
}

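// Check whether any of the current inputs differs from the recorded ones.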
bool IsInputsNodeInfoChange(const std::vector<NodeInfo> &old_inputs_node_info,
                            const std::vector<NodeInfo> &new_inputs_node_info) {
  size_t input_size = old_inputs_node_info.size();
  if (input_size != new_inputs_node_info.size()) {
    MS_LOG(DEBUG) << "Graph is dynamic, old_inputs size: " << input_size
                  << ", new_inputs size: " << new_inputs_node_info.size();
    return true;
  }
  for (size_t i = 0; i < input_size; ++i) {
    if (IsNodeInfoChange(old_inputs_node_info[i], new_inputs_node_info[i])) {
      return true;
    }
  }
  return false;
}

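// Build a NodeInfo from a runtime value: tensors carry their auto-grad meta data,
// sequences are expanded element-wise, stub nodes are resolved first, and anything
// else is recorded as a constant.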
NodeInfo GetNodeInfoFromValue(const ValuePtr &input) {
  if (input->isa<tensor::BaseTensor>()) {
    NodeInfo node_info;
    auto tensor = input->cast<tensor::BaseTensorPtr>();
    auto auto_meta_data = tensor->auto_grad_meta_data();
    // Scalar tensor
    if (auto_meta_data == nullptr) {
      node_info.grad_type = InputType::kConstant;
      node_info.value = input;
      return node_info;
    }

    // Tensor
    node_info.grad_type = auto_meta_data->input_type();
    node_info.op_index = auto_meta_data->op_index();
    if (node_info.grad_type == InputType::kConstant || node_info.grad_type == InputType::kParameter) {
      node_info.value = input;
    }
    return node_info;
  } else if (input->isa<ValueSequence>()) {
    NodeInfo node_info;
    const auto &value_sequence = input->cast<ValueSequencePtr>();
    for (const auto &i : value_sequence->value()) {
      node_info.seq_node.emplace_back(GetNodeInfoFromValue(i));
    }
    return node_info;
  } else if (input->isa<stub::StubNode>()) {
    auto stub_node = input->cast<stub::StubNodePtr>();
    MS_EXCEPTION_IF_NULL(stub_node);
    return GetNodeInfoFromValue(stub_node->WaitValue());
  } else {
    NodeInfo node_info;
    node_info.grad_type = InputType::kConstant;
    node_info.value = input;
    return node_info;
  }
}

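// Change detection based on the input/output abstracts (type and shape) plus the
// recorded input node info.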
struct CompareBasedOnAbstract {
  static bool IsNodeChange(const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &old_node,
                           const DynamicDetectNodeInfoPtr &new_node) {
    // Compare input abs
    if (IsDynamicDetectAbsChange(old_node->abs_compare_info.input_abs, new_node->abs_compare_info.input_abs)) {
      return true;
    }

    // Compare out abs
    if (IsDynamicDetectAbsChange(old_node->abs_compare_info.out_abs, new_node->abs_compare_info.out_abs)) {
      return true;
    }

    // Get input
    BuildDynamicDetectInputsNodeInfo(new_node, inputs);

    // Compare input
    return IsInputsNodeInfoChange(old_node->abs_compare_info.inputs, new_node->abs_compare_info.inputs);
  }

  static bool IsDynamicDetectAbsChange(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
    if (old_abs == new_abs) {
      return false;
    }
    if (old_abs == nullptr || new_abs == nullptr) {
      MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different from new_abs";
      return true;
    }
    if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
        !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
      MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different from new_abs, old abs: " << old_abs->ToString()
                    << ", new abs: " << new_abs->ToString();
      return true;
    }
    return false;
  }

  static bool IsDynamicDetectAbsChange(const abstract::AbstractBasePtrList &node_abs,
                                       const abstract::AbstractBasePtrList &old_node_abs) {
    if (node_abs.size() != old_node_abs.size()) {
      MS_LOG(DEBUG) << "Graph is dynamic, node_abs size: " << node_abs.size()
                    << ", old_node_abs size: " << old_node_abs.size();
      return true;
    }
    for (size_t i = 0; i < node_abs.size(); ++i) {
      if (IsDynamicDetectAbsChange(node_abs[i], old_node_abs[i])) {
        return true;
      }
    }
    return false;
  }

  static void BuildDynamicDetectInputsNodeInfo(const DynamicDetectNodeInfoPtr &node, const ValuePtrList &inputs) {
    std::transform(inputs.begin(), inputs.end(), std::back_inserter(node->abs_compare_info.inputs),
                   [](const auto &item) { return GetNodeInfoFromValue(item); });
  }
};

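// Change detection based on lightweight value info (shape, dtype and object type)
// instead of full abstracts.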
struct CompareBasedOnValueSimpleInfo {
  static bool IsNodeChange(const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &old_node,
                           const DynamicDetectNodeInfoPtr &new_node) {
    BuildInputsValueSimpleInfo(new_node, inputs);
    return IsInputsChange(old_node->value_compare_info, new_node->value_compare_info);
  }

  static void BuildInputsValueSimpleInfo(const DynamicDetectNodeInfoPtr &node, const ValuePtrList &inputs) {
    size_t input_size = inputs.size();
    node->value_compare_info.input_value_simple_info.size_ = input_size;
    node->value_compare_info.input_value_simple_info.shape_vector_.reserve(input_size);
    node->value_compare_info.input_value_simple_info.dtype_vector_.reserve(input_size);
    node->value_compare_info.input_value_simple_info.object_type_vector_.reserve(input_size);
    for (const auto &input : inputs) {
      node->value_compare_info.inputs.emplace_back(GetNodeInfoFromValue(input));

      (void)node->value_compare_info.input_value_simple_info.shape_vector_.emplace_back(
        PyNativeAlgo::Common::GetShapeFromValue(input));
      auto [dtype, obj_type] = PyNativeAlgo::Common::GetTypeFromValue(input);
      (void)node->value_compare_info.input_value_simple_info.dtype_vector_.emplace_back(dtype);
      (void)node->value_compare_info.input_value_simple_info.object_type_vector_.emplace_back(obj_type);
    }
  }

  static bool IsInputsChange(const ValueCompareInfo &old_value_compare_info,
                             const ValueCompareInfo &new_value_compare_info) {
    if (IsInputsNodeInfoChange(old_value_compare_info.inputs, new_value_compare_info.inputs)) {
      return true;
    }
    return IsValueSimpleInfoChange(old_value_compare_info.input_value_simple_info,
                                   new_value_compare_info.input_value_simple_info);
  }

  template <typename T1, typename T2>
  static bool IsNotEqual(const T1 &old_input, const T2 &new_input) {
    return old_input != new_input;
  }

  template <typename T1, typename T2>
  static bool IsNotEqual(const std::shared_ptr<T1> &old_input, const std::shared_ptr<T2> &new_input) {
    MS_EXCEPTION_IF_NULL(old_input);
    MS_EXCEPTION_IF_NULL(new_input);
    return old_input->type_id() != new_input->type_id();
  }

  static bool IsValueSimpleInfoChange(const ValueSimpleInfo &old_input_simple_info,
                                      const ValueSimpleInfo &new_input_simple_info) {
    if (old_input_simple_info.size_ != new_input_simple_info.size_) {
      MS_LOG(DEBUG) << "Graph is dynamic, old_input_simple_info size: " << old_input_simple_info.size_
                    << ", new_input_simple_info size: " << new_input_simple_info.size_;
      return true;
    }
    for (size_t i = 0; i < old_input_simple_info.size_; ++i) {
      if (IsNotEqual(old_input_simple_info.shape_vector_[i], new_input_simple_info.shape_vector_[i]) ||
          IsNotEqual(old_input_simple_info.dtype_vector_[i], new_input_simple_info.dtype_vector_[i]) ||
          IsNotEqual(old_input_simple_info.object_type_vector_[i], new_input_simple_info.object_type_vector_[i])) {
        MS_LOG(DEBUG) << "Graph is dynamic, old input simple info: " << ValueSimpleInfoToString(old_input_simple_info)
                      << ", new input simple info: " << ValueSimpleInfoToString(new_input_simple_info);
        return true;
      }
    }
    return false;
  }
};

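// Update the forward executor's node-abs cache for arg_id. When abs is null, an
// AbstractTensor is synthesized from the tensor dtype and the given base_shape.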
void UpdateAbsCache(const std::string &arg_id, const ValuePtr &v, const abstract::BaseShapePtr &base_shape,
                    const abstract::AbstractBasePtr &abs, size_t index) {
  auto update_abs = abs;
  if (update_abs == nullptr) {
    MS_EXCEPTION_IF_NULL(v);
    auto input_tensor = v->cast<tensor::BaseTensorPtr>();
    // Only tensors are handled for unknown shape
    if (input_tensor == nullptr) {
      return;
    }
    MS_EXCEPTION_IF_NULL(base_shape);
    update_abs = std::make_shared<abstract::AbstractTensor>(input_tensor->Dtype(), base_shape);
  }
  MS_LOG(DEBUG) << "Set arg " << index << ", id " << arg_id << ", to dynamic abs: " << update_abs->ToString();
  const auto &infer = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->infer_operation();
  infer->UpdateNodeAbsCacheById(arg_id, update_abs);
}

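// Merge the current shape with the previous top cell shape into *new_shape: a rank
// mismatch yields kShapeRankAny, a dim mismatch yields kShapeDimAny. Returns false
// if the merged shape is still fully static.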
bool GetUnknownShape(const ShapeVector &cur_shape, const ShapeVector &pre_top_cell_shape, ShapeVector *new_shape) {
  // Dynamic rank
  if (cur_shape.size() != pre_top_cell_shape.size()) {
    MS_LOG(INFO) << "Cur shape size " << cur_shape.size() << " is not equal to top cell arg shape size "
                 << pre_top_cell_shape.size();
    (void)new_shape->emplace_back(abstract::Shape::kShapeRankAny);
    return true;
  }
  // Dynamic shape
  for (size_t j = 0; j < cur_shape.size(); ++j) {
    if (cur_shape[j] == pre_top_cell_shape[j]) {
      (void)new_shape->emplace_back(cur_shape[j]);
    } else {
      (void)new_shape->emplace_back(abstract::Shape::kShapeDimAny);
    }
  }
  // If every dim is still actual, the shape is static.
  if (!IsDynamicShape(*new_shape)) {
    MS_LOG(DEBUG) << "All dims are actual, so this is a static shape. Cur shape " << cur_shape << ", elem shape "
                  << pre_top_cell_shape << ", and new shape is " << *new_shape;
    return false;
  }
  return true;
}

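// Check whether cur_shape is compatible with a previously recorded, possibly
// dynamic, top cell shape.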
bool IsMatch(const ShapeVector &cur_shape, const ShapeVector &pre_top_cell_shape) {
  // Dynamic rank matches any shape
  if (!pre_top_cell_shape.empty() && pre_top_cell_shape[kIndex0] == abstract::Shape::kShapeRankAny) {
    return true;
  }
  if (cur_shape.size() != pre_top_cell_shape.size()) {
    MS_LOG(DEBUG) << "Cur shape size " << cur_shape.size() << " is not equal to pre top cell arg shape size "
                  << pre_top_cell_shape.size();
    return false;
  }
  // Dynamic shape
  for (size_t i = 0; i < cur_shape.size(); ++i) {
    if (cur_shape[i] != pre_top_cell_shape[i] && pre_top_cell_shape[i] != abstract::Shape::kShapeDimAny) {
      MS_LOG(DEBUG) << "Cur shape " << cur_shape[i] << " can not match pre top cell shape " << pre_top_cell_shape[i];
      return false;
    }
  }
  return true;
}
}  // namespace

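// Convert actual Python inputs into inputs for dynamic compilation, attaching any
// cached dynamic base shape to tensors; tuples and lists are handled recursively.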
py::object DynamicShape::GetDynamicInput(const py::object &actual_input) {
  if (py::isinstance<py::tuple>(actual_input)) {
    auto tuple_actual_args = py::cast<py::tuple>(actual_input);
    size_t args_size = tuple_actual_args.size();
    py::tuple dyn_shape_args = py::tuple(args_size);
    for (size_t i = 0; i < args_size; ++i) {
      dyn_shape_args[i] = GetDynamicInput(tuple_actual_args[i]);
    }
    return dyn_shape_args;
  } else if (py::isinstance<py::list>(actual_input)) {
    auto list_actual_args = py::cast<py::list>(actual_input);
    size_t args_size = list_actual_args.size();
    py::list dyn_shape_args;
    for (size_t i = 0; i < args_size; ++i) {
      dyn_shape_args.append(GetDynamicInput(list_actual_args[i]));
    }
    return dyn_shape_args;
  } else if (py::isinstance<tensor::BaseTensor>(actual_input)) {
    const auto &infer = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->infer_operation();
    auto tensor_ptr = py::cast<tensor::BaseTensorPtr>(actual_input);
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    auto dyn_compile_tensor = std::make_shared<tensor::BaseTensor>(tensor_ptr->data_type(), tensor_ptr->shape_c());
    const auto &abs = infer->GetNodeAbsById(PyNativeAlgo::PyParser::GetIdByPyObj(actual_input));
    if (abs != nullptr) {
      auto base_shape = abs->BuildShape();
      MS_EXCEPTION_IF_NULL(base_shape);
      if (base_shape->IsDynamic()) {
        dyn_compile_tensor->set_base_shape(base_shape);
      }
    }
    return PyNativeAlgo::DataConvert::ValueToPyObj(dyn_compile_tensor);
  }
  return actual_input;
}

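// Recursively cache dynamic abstracts produced by jit for tensor (sequence) args,
// so that later runs see them as unknown shape.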
void DynamicShape::SaveUnknownShapeAbsFromJit(const ValuePtr &v, const AbstractBasePtr &abs, size_t index) {
  MS_EXCEPTION_IF_NULL(v);
  MS_EXCEPTION_IF_NULL(abs);
  if (v->isa<ValueSequence>() && abs->isa<abstract::AbstractSequence>()) {
    const auto &v_seq = v->cast<ValueSequencePtr>();
    const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
    if (v_seq->size() != abs_seq->size()) {
      MS_LOG(EXCEPTION) << "Obj tuple size " << v_seq->size() << ", but abstract tuple size " << abs_seq->size();
    }
    for (size_t i = 0; i < v_seq->size(); ++i) {
      SaveUnknownShapeAbsFromJit(v_seq->value()[i], abs_seq->elements()[i], index);
    }
  } else if (v->isa<tensor::BaseTensor>() && abs->isa<abstract::AbstractTensor>()) {
    if (abs->BuildShape()->IsDynamic()) {
      UpdateAbsCache(PyNativeAlgo::Common::GetIdByValue(v), v, nullptr, abs, ++index);
    }
  } else {
    MS_LOG(EXCEPTION) << "Mismatch: obj " << v->ToString() << " and abs " << abs->ToString();
  }
}

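// Called for every executed op. Records the node on the first run, otherwise compares
// it with the recorded trace and switches the top cell to the dynamic shape process
// on any mismatch (or if the top cell has a bprop cut op).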
bool NodeDynamicDetect::CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
                                         const DynamicDetectNodeInfoPtr &node) {
  std::unique_lock<std::mutex> lock(async_mutex_);
  MS_EXCEPTION_IF_NULL(top_cell);
  if (top_cell->use_dynamic_shape_process()) {
    top_cell->IncreaseOpIndex();
    return true;
  }

  const size_t node_idx = top_cell->op_index();
  bool node_is_dynamic = false;
  bool use_dynamic_shape_process =
    top_cell->has_bprop_cut_op() || (node_is_dynamic = IsNodeDynamic(top_cell, inputs, node, node_idx));
  top_cell->IncreaseOpIndex();
  if (use_dynamic_shape_process) {
    MS_LOG(INFO) << "Set use_dynamic_shape_process: " << use_dynamic_shape_process;
    top_cell->set_use_dynamic_shape_process(use_dynamic_shape_process);
    py::gil_scoped_acquire gil_acquire;
    (void)cell_id_with_dynamic_detect_nodes_.erase(top_cell->obj_id_with_grad_order());
  }
  if (node_is_dynamic) {
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    if (context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
      MS_LOG(WARNING) << "Detected dynamic shape or dynamic graph structure, the python stack is: ";
      py::gil_scoped_acquire acquire_gil;
      py::exec(R"(
                  import traceback
                  traceback.print_stack()
                  )");
    }
  }
  return use_dynamic_shape_process;
}

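// Compare the current node with the node recorded at the same index in the first run:
// jit phase first, then the primitive, then the inputs.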
bool NodeDynamicDetect::IsNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
                                      const DynamicDetectNodeInfoPtr &node, size_t node_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (top_cell->is_need_save_dynamic_detect_nodes()) {
    SaveDynamicDetectNodeInfoInFirstTime(top_cell, inputs, node, node_idx);
    // The net is regarded as static by default on the first run.
    return false;
  }

  MS_LOG(DEBUG) << "Check node " << (node->op_prim != nullptr ? node->op_prim->name() : "")
                << " node_idx: " << node_idx << ", is_jit_node: " << node->is_graph_node
                << ", graph_phase: " << node->graph_phase
                << ", obj_id_with_grad_order: " << top_cell->obj_id_with_grad_order()
                << ", cell id: " << top_cell->cell_id();
  const auto &dynamic_nodes =
    cell_id_with_dynamic_detect_nodes_[top_cell->obj_id_with_grad_order()][top_cell->cell_id()];
  if (node_idx >= dynamic_nodes.size()) {
    MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << ", cur node_idx is: " << node_idx
                  << ", graph is dynamic.";
    return true;
  }

  // 1. Detect jit phase
  const DynamicDetectNodeInfoPtr &old_node_info = dynamic_nodes[node_idx];
  if (node->is_graph_node) {
    if (!old_node_info->is_graph_node || node->graph_phase != old_node_info->graph_phase) {
      MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
                    << ", new is_graph_node: " << node->is_graph_node << ", old graph_phase: "
                    << old_node_info->graph_phase << ", new graph_phase: " << node->graph_phase;
      return true;
    }
    return false;
  }

  // 2. Detect prim
  if (IsDynamicDetectPrimChange(old_node_info->op_prim, node->op_prim)) {
    MS_LOG(DEBUG) << "Graph is dynamic, old node prim: "
                  << (old_node_info->op_prim != nullptr
                        ? old_node_info->op_prim->name() + ", attr: " + old_node_info->op_prim->GetAttrsText()
                        : "")
                  << " new node prim: "
                  << (node->op_prim != nullptr ? node->op_prim->name() + ", attr: " + node->op_prim->GetAttrsText()
                                               : "")
                  << " node_idx: " << node_idx;
    return true;
  }

  // 3. Detect inputs
  if (node->is_value_compare) {
    return CompareBasedOnValueSimpleInfo::IsNodeChange(inputs, old_node_info, node);
  }
  return CompareBasedOnAbstract::IsNodeChange(inputs, old_node_info, node);
}

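// Record the node info of the first run so that later runs can be compared against it.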
void NodeDynamicDetect::SaveDynamicDetectNodeInfoInFirstTime(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
                                                             const DynamicDetectNodeInfoPtr &node, size_t node_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->is_value_compare) {
    CompareBasedOnValueSimpleInfo::BuildInputsValueSimpleInfo(node, inputs);
  } else {
    CompareBasedOnAbstract::BuildDynamicDetectInputsNodeInfo(node, inputs);
  }
  (void)cell_id_with_dynamic_detect_nodes_[top_cell->obj_id_with_grad_order()][top_cell->cell_id()].emplace_back(node);
  MS_LOG(DEBUG) << "Save node " << (node->op_prim != nullptr ? node->op_prim->name() : "")
                << " for the first time, node_idx: " << node_idx << ", is_jit_node: " << node->is_graph_node
                << ", graph_phase: " << node->graph_phase
                << ", obj_id_with_grad_order: " << top_cell->obj_id_with_grad_order()
                << ", cell id: " << top_cell->cell_id();
}

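// Decide whether this run should record node info. The first run of a cell is always
// recorded; a second, differently shaped run is recorded too (up to
// kMaxCacheDynamicShapeCellNum), while a third distinct shape switches the top cell
// to the dynamic shape process.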
bool NodeDynamicDetect::IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process) {
  if (use_dynamic_shape_process) {
    // The top cell already uses the dynamic shape process, no need to save nodes.
    return false;
  }
  MS_EXCEPTION_IF_NULL(top_cell);
  auto cell_iter = cell_id_with_dynamic_detect_nodes_.find(top_cell->obj_id_with_grad_order());
  if (cell_iter == cell_id_with_dynamic_detect_nodes_.end()) {
    // Cell is not found in cell_id_with_dynamic_detect_nodes_, nodes need to be saved first.
    return true;
  }

  const auto &cell_infos = cell_iter->second;
  if (cell_infos.size() == 1) {
    // top_cell->cell_id() is the cell id including the input shapes. If the cached cell id
    // is the same as top_cell->cell_id(), there is no need to save nodes.
    return cell_infos.begin()->first != top_cell->cell_id();
  } else if (cell_infos.size() == kMaxCacheDynamicShapeCellNum) {
    auto cell_infos_iter = cell_infos.find(top_cell->cell_id());
    if (cell_infos_iter == cell_infos.end()) {
      // cell_id_with_dynamic_detect_nodes_ already holds two cell ids and the current cell
      // differs from both, so enable use_dynamic_shape_process for the top cell.
      top_cell->set_use_dynamic_shape_process(true);
      (void)cell_id_with_dynamic_detect_nodes_.erase(top_cell->obj_id_with_grad_order());
      MS_LOG(INFO) << "Set use_dynamic_shape_process: true, already cached " << cell_infos.size()
                   << " top cells, cur top cell shape is different: " << top_cell->cell_id();
    }
  } else {
    MS_LOG(EXCEPTION) << "cell_info.size(): " << cell_infos.size() << " is invalid";
  }
  return false;
}

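// Handle Cell.set_inputs: record the declared input base shapes for the object and
// try to change the corresponding top cell to unknown shape.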
void TopCellUnknownShapeDetect::SetDynamicInput(const py::object &obj, const py::args &args) {
  const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
  // After the first step, set_inputs does not need to take effect again, because the top cell
  // of the first step is already unknown shape and the following steps keep it unknown shape,
  // especially for input_signature.
  if (obj_with_by_inputs_.find(obj_id) != obj_with_by_inputs_.end()) {
    MS_LOG(DEBUG) << "Obj " << obj_id << " has done set_inputs before";
    return;
  }
  auto &arg_base_shape_vec = obj_id_args_info_by_set_inputs_[obj_id];
  size_t args_size = args.size();
  arg_base_shape_vec.reserve(args_size);
  for (size_t i = 0; i < args_size; ++i) {
    (void)arg_base_shape_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i])->ToAbstract()->BuildShape());
  }
  TryChangeTopCellToUnknownShape(obj_id, arg_base_shape_vec, false);
  (void)obj_with_by_inputs_.emplace(obj_id);
}

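// Try to turn the current top cell into an unknown shape cell, either from auto
// detection (a previously run top cell with the same obj_id exists) or from a
// set_inputs declaration.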
void TopCellUnknownShapeDetect::TryChangeTopCellToUnknownShape(const std::string &obj_id,
                                                               const abstract::BaseShapePtrList &arg_base_shape_vec,
                                                               bool is_auto_detect) {
  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  if (is_auto_detect) {
    // From auto detect
    auto &top_cell_list = grad_executor->already_run_top_cell();
    const auto it = std::find_if(top_cell_list.begin(), top_cell_list.end(), [&obj_id](const auto &elem) {
      return elem.second->input_args_info() != nullptr && elem.second->input_args_info()->obj_id == obj_id;
    });
    if (it != top_cell_list.end()) {
      // The previous top cell is already unknown shape; check whether the current top cell matches it
      if (it->second->is_unknown_shape() && CanFindMatchedUnknownShapeTopCell(it->second, arg_base_shape_vec)) {
        MS_LOG(DEBUG) << "Pre top cell has already been unknown shape and can match current top cell";
        ChangeTopCellToUnknownShape(grad_executor->top_cell(), it->second->input_args_info()->input_arg_base_shape_vec);
        return;
      }
      // If no match above, compare the shapes and change the current top cell to unknown shape
      if (SetTopCellUnknownShape(grad_executor->top_cell(), it->second, arg_base_shape_vec)) {
        (void)top_cell_list.erase(it);
        return;
      }
    } else {
      // set_inputs path: the first-step top cell is handled here
      const auto item = obj_id_args_info_by_set_inputs_.find(grad_executor->top_cell()->input_args_info()->obj_id);
      if (item != obj_id_args_info_by_set_inputs_.end()) {
        const auto &input_args_info = grad_executor->top_cell()->input_args_info();
        UpdateUnknownShapeAbsCache(input_args_info->input_arg_id_vec, input_args_info->input_arg_value_vec,
                                   item->second);
        (void)obj_id_args_info_by_set_inputs_.erase(item);
        return;
      }
      // C1.set_inputs, run C1(x); C2 is top cell, and run C2(x).
      if (std::any_of(arg_base_shape_vec.begin(), arg_base_shape_vec.end(),
                      [](const abstract::BaseShapePtr &base_shape) { return base_shape->IsDynamic(); })) {
        MS_LOG(DEBUG) << "Top cell is unknown shape now";
        grad_executor->top_cell()->set_is_unknown_shape(true);
      }
    }
  } else {
    // From set_inputs; the top cell has not been created yet
    if (grad_executor->TopCellHasNotBeenCreate()) {
      return;
    }
    // For jit, the top cell is created first, then set_inputs runs
    const auto item = obj_id_args_info_by_set_inputs_.find(grad_executor->top_cell()->input_args_info()->obj_id);
    if (item != obj_id_args_info_by_set_inputs_.end()) {
      MS_LOG(DEBUG) << "Get jit set_inputs";
      ChangeTopCellToUnknownShape(grad_executor->top_cell(), arg_base_shape_vec);
      (void)obj_id_args_info_by_set_inputs_.erase(item);
    }
  }
}

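// Write dynamic base shapes into the abs cache for each argument id; a sequence
// argument is expanded element-wise after splitting its combined id.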
void TopCellUnknownShapeDetect::UpdateUnknownShapeAbsCache(const std::vector<string> &input_arg_id_vec,
                                                           const std::vector<ValuePtr> &input_arg_value_vec,
                                                           const std::vector<abstract::BaseShapePtr> &args_base_shape) {
  for (size_t i = 0; i < args_base_shape.size(); i++) {
    MS_EXCEPTION_IF_NULL(args_base_shape[i]);
    MS_EXCEPTION_IF_NULL(input_arg_value_vec[i]);
    if (args_base_shape[i]->IsDynamic()) {
      if (args_base_shape[i]->isa<abstract::Shape>()) {
        UpdateAbsCache(input_arg_id_vec[i], input_arg_value_vec[i], args_base_shape[i], nullptr, i);
      } else if (args_base_shape[i]->isa<abstract::SequenceShape>()) {
        // Input arg is a list or tuple
        const auto &seq_shape = args_base_shape[i]->cast<abstract::SequenceShapePtr>();
        const auto &seq_v = input_arg_value_vec[i]->cast<ValueSequencePtr>();
        MS_EXCEPTION_IF_NULL(seq_v);
        if (seq_v->size() != seq_shape->size()) {
          MS_LOG(EXCEPTION) << "Sequence value size " << seq_v->size() << " is not equal to seq shape size "
                            << seq_shape->size();
        }
        std::vector<std::string> id_vec;
        PyNativeAlgo::Common::SplitString(input_arg_id_vec[i], &id_vec);
        if (id_vec.size() != seq_shape->size()) {
          MS_LOG(EXCEPTION) << "Id size " << id_vec.size() << " is not equal to seq shape size " << seq_shape->size();
        }
        for (size_t j = 0; j < seq_shape->size(); ++j) {
          UpdateAbsCache(id_vec[j], seq_v->value()[j], seq_shape->shape()[j], nullptr, i + j);
        }
      }
    }
  }
}

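// Called when a cell runs while set_inputs info is pending: update the abs cache for
// its args and, if the running top cell shares those args, make it unknown shape.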
void TopCellUnknownShapeDetect::UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args) {
  if (obj_id_args_info_by_set_inputs_.empty()) {
    return;
  }

  const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
  bool top_cell_has_not_been_create = grad_executor->TopCellHasNotBeenCreate();
  // Top cell is already unknown shape
  if (!top_cell_has_not_been_create && grad_executor->top_cell()->is_unknown_shape()) {
    return;
  }

  // The current cell has no set_inputs
  const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
  const auto it = obj_id_args_info_by_set_inputs_.find(obj_id);
  if (it == obj_id_args_info_by_set_inputs_.end()) {
    return;
  }

  // For a common cell, the arg ids and values are not created in ParsePyArgsToInputArgsInfo, so get them now.
  // Update the current cell id cache, which may be used for the top cell.
  const auto &args_id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args);
  UpdateUnknownShapeAbsCache(args_id_v.first, args_id_v.second, it->second);

  // C1.set_inputs, run C1(x); C2 is top cell, and run C2(x).
  if (top_cell_has_not_been_create) {
    // The top cell has not been created yet
    (void)obj_id_args_info_by_set_inputs_.erase(it);
    return;
  }

  // C1 is top cell, run C1(x); C2 set_inputs, and run C2(x).
  UpdatePossibleTopCellToUnknownShape(grad_executor->top_cell(), args_id_v.first, it->second);
  (void)obj_id_args_info_by_set_inputs_.erase(it);
}

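// If a set_inputs arg of an inner cell is also an arg of the running top cell and its
// shape is dynamic, propagate that shape into the top cell info.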
void TopCellUnknownShapeDetect::UpdatePossibleTopCellToUnknownShape(const TopCellInfoPtr &cur_top_cell,
                                                                    const std::vector<string> &cur_arg_id_vec,
                                                                    const abstract::BaseShapePtrList &cur_args_shape) {
  MS_LOG(DEBUG) << "Update possible top cell";
  auto cur_top_cell_base_shape_vec = cur_top_cell->input_args_info()->input_arg_base_shape_vec;
  const auto &cur_top_cell_id_vec = cur_top_cell->input_args_info()->input_arg_id_vec;
  bool need_change_top_cell_info = false;
  // Check whether a top cell arg id is the same as one of the current set_inputs cell.
  // If that arg's shape is dynamic, update the top cell to unknown shape.
  for (size_t i = 0; i < cur_arg_id_vec.size(); ++i) {
    auto it = std::find(cur_top_cell_id_vec.begin(), cur_top_cell_id_vec.end(), cur_arg_id_vec[i]);
    if (it != cur_top_cell_id_vec.end() && cur_args_shape[i]->IsDynamic()) {
      auto id_index = it - cur_top_cell_id_vec.begin();
      cur_top_cell_base_shape_vec[id_index] = cur_args_shape[i];
      need_change_top_cell_info = true;
    }
  }
  // Change the current top cell info
  if (need_change_top_cell_info) {
    cur_top_cell->ChangeTopCellInfo(cur_top_cell_base_shape_vec);
  }
}

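// Check whether the current arg shapes are compatible with the possibly dynamic
// shapes recorded on a previously run unknown shape top cell.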
bool TopCellUnknownShapeDetect::CanFindMatchedUnknownShapeTopCell(const TopCellInfoPtr &pre_top_cell,
                                                                  const abstract::BaseShapePtrList &cur_args_shape) {
  for (size_t i = 0; i < cur_args_shape.size(); ++i) {
    const auto &cur_shape = cur_args_shape[i];
    const auto &pre_top_cell_shape = pre_top_cell->input_args_info()->input_arg_base_shape_vec[i];
    MS_EXCEPTION_IF_NULL(cur_shape);
    MS_EXCEPTION_IF_NULL(pre_top_cell_shape);
    if (cur_shape->isa<abstract::Shape>() && pre_top_cell_shape->isa<abstract::Shape>()) {
      if (!IsMatch(cur_shape->cast<abstract::ShapePtr>()->shape(),
                   pre_top_cell_shape->cast<abstract::ShapePtr>()->shape())) {
        return false;
      }
    } else if (cur_shape->isa<abstract::SequenceShape>() && pre_top_cell_shape->isa<abstract::SequenceShape>()) {
      // Input arg is a list or tuple
      const auto &cur_shape_seq = cur_shape->cast<abstract::SequenceShapePtr>();
      const auto &top_cell_shape_seq = pre_top_cell_shape->cast<abstract::SequenceShapePtr>();
      size_t cur_shape_size = cur_shape_seq->size();
      if (cur_shape_size != top_cell_shape_seq->size()) {
        MS_LOG(DEBUG) << "The " << i << "th args shape size is not the same, cur is " << cur_shape_seq->size()
                      << " and the elem is " << top_cell_shape_seq->size();
        return false;
      }
      for (size_t j = 0; j < cur_shape_size; ++j) {
        MS_EXCEPTION_IF_NULL(cur_shape_seq->shape()[j]);
        MS_EXCEPTION_IF_NULL(top_cell_shape_seq->shape()[j]);
        if (!IsMatch(cur_shape_seq->shape()[j]->cast<abstract::ShapePtr>()->shape(),
                     top_cell_shape_seq->shape()[j]->cast<abstract::ShapePtr>()->shape())) {
          return false;
        }
      }
    }
  }
  return true;
}

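// Apply the given unknown shapes to the top cell: update the abs cache for its args
// and rewrite the top cell info.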
void TopCellUnknownShapeDetect::ChangeTopCellToUnknownShape(const TopCellInfoPtr &top_cell,
                                                            const abstract::BaseShapePtrList &args_unknown_shape) {
  if (top_cell->input_args_info()->input_arg_base_shape_vec.size() != args_unknown_shape.size()) {
    MS_LOG(EXCEPTION) << "Top cell args base shape size "
                      << top_cell->input_args_info()->input_arg_base_shape_vec.size()
                      << " is not equal to update unknown shape size " << args_unknown_shape.size();
  }
  UpdateUnknownShapeAbsCache(top_cell->input_args_info()->input_arg_id_vec,
                             top_cell->input_args_info()->input_arg_value_vec, args_unknown_shape);
  top_cell->ChangeTopCellInfo(args_unknown_shape);
}

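// Merge the current arg shapes with those of a previous top cell; if every arg yields
// an unknown shape, switch the current top cell to the merged shapes.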
bool TopCellUnknownShapeDetect::SetTopCellUnknownShape(const TopCellInfoPtr &cur_top_cell,
                                                       const TopCellInfoPtr &pre_top_cell,
                                                       const abstract::BaseShapePtrList &args_shape) {
  abstract::BaseShapePtrList args_unknown_shape;
  args_unknown_shape.reserve(args_shape.size());
  for (size_t i = 0; i < args_shape.size(); ++i) {
    const auto &cur_shape = args_shape[i];
    const auto &pre_top_cell_shape = pre_top_cell->input_args_info()->input_arg_base_shape_vec[i];
    MS_EXCEPTION_IF_NULL(cur_shape);
    MS_EXCEPTION_IF_NULL(pre_top_cell_shape);
    if (cur_shape->isa<abstract::Shape>() && pre_top_cell_shape->isa<abstract::Shape>()) {
      ShapeVector new_shape;
      auto has_unknown = GetUnknownShape(cur_shape->cast<abstract::ShapePtr>()->shape(),
                                         pre_top_cell_shape->cast<abstract::ShapePtr>()->shape(), &new_shape);
      if (has_unknown) {
        (void)args_unknown_shape.emplace_back(std::make_shared<abstract::Shape>(new_shape));
      }
    } else if (cur_shape->isa<abstract::SequenceShape>() && pre_top_cell_shape->isa<abstract::SequenceShape>()) {
      // Input arg is a list or tuple
      const auto &cur_shape_seq = cur_shape->cast<abstract::SequenceShapePtr>();
      MS_EXCEPTION_IF_NULL(cur_shape_seq);
      const auto &pre_top_cell_shape_seq = pre_top_cell_shape->cast<abstract::SequenceShapePtr>();
      size_t cur_shape_size = cur_shape_seq->size();
      if (cur_shape_size != pre_top_cell_shape_seq->size()) {
        MS_LOG(DEBUG) << "The " << i << "th args shape size is not the same, cur is " << cur_shape_seq->size()
                      << " and the elem is " << pre_top_cell_shape_seq->size();
        return false;
      }
      abstract::BaseShapePtrList shape_ptr_list;
      for (size_t j = 0; j < cur_shape_size; ++j) {
        const auto &cur_shape_elem = cur_shape_seq->shape()[j]->cast<abstract::ShapePtr>();
        const auto &pre_top_cell_shape_elem = pre_top_cell_shape_seq->shape()[j]->cast<abstract::ShapePtr>();
        MS_EXCEPTION_IF_NULL(pre_top_cell_shape_elem);
        ShapeVector new_shape;
        auto has_unknown = GetUnknownShape(cur_shape_elem->shape(), pre_top_cell_shape_elem->shape(), &new_shape);
        if (has_unknown) {
          (void)shape_ptr_list.emplace_back(std::make_shared<abstract::Shape>(new_shape));
        }
      }
      if (shape_ptr_list.size() == cur_shape_size) {
        (void)args_unknown_shape.emplace_back(std::make_shared<abstract::TupleShape>(shape_ptr_list));
      }
    } else {
      MS_LOG(DEBUG) << "The " << i << "th args shape type is not the same, cur is " << cur_shape->ToString()
                    << " and the elem is " << pre_top_cell_shape->ToString();
      return false;
    }
  }
  if (args_unknown_shape.size() == args_shape.size()) {
    ChangeTopCellToUnknownShape(cur_top_cell, args_unknown_shape);
    return true;
  }
  return false;
}
}  // namespace pynative
}  // namespace mindspore