1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/grad/ir/dynamic_shape.h"
18 #include <algorithm>
19 #include "pipeline/pynative/pynative_utils.h"
20
21 namespace mindspore {
22 namespace pynative {
23 namespace {
// Names of the primitive attributes that IsDynamicDetectPrimChange erases from the recorded primitive before
// comparing it with the current one.
constexpr const char *kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr const char *kIsFeatureMapInputList = "IsFeatureMapInputList";
// Number of cached cell ids per object at which an unseen cell id makes IsNeedSaveDynamicDetectNodes switch the
// top cell to the dynamic shape process.
constexpr size_t kMaxCacheDynamicShapeCellNum{2};
27
IsValuePtrEqual(const ValuePtr & v1,const ValuePtr & v2)28 bool IsValuePtrEqual(const ValuePtr &v1, const ValuePtr &v2) {
29 if (v1 == v2) {
30 return true;
31 }
32 if (v1 == nullptr || v2 == nullptr) {
33 return false;
34 }
35 if (v1->isa<tensor::BaseTensor>() && v2->isa<tensor::BaseTensor>()) {
36 return v1->cast<tensor::BaseTensorPtr>()->ValueEqual(*(v2->cast<tensor::BaseTensorPtr>()));
37 }
38 return *v1 == *v2;
39 }
40
IsDynamicDetectPrimChange(const PrimitivePtr & old_prim,const PrimitivePtr & new_prim)41 bool IsDynamicDetectPrimChange(const PrimitivePtr &old_prim, const PrimitivePtr &new_prim) {
42 if (old_prim == nullptr && new_prim == nullptr) {
43 return false;
44 }
45 // Use kernel graph will add kIsFeatureMapOutput adn kIsFeatureMapOutput attr,
46 // but check must be remove them
47 if (old_prim != nullptr && old_prim->HasAttr(kIsFeatureMapOutput)) {
48 old_prim->EraseAttr(kIsFeatureMapOutput);
49 old_prim->EraseAttr(kIsFeatureMapInputList);
50 }
51 if (new_prim != nullptr && old_prim != nullptr) {
52 return !common::IsEqual(old_prim, new_prim);
53 }
54 return true;
55 }
56
IsNodeInfoChange(const NodeInfo & old_node_info,const NodeInfo & new_node_info)57 bool IsNodeInfoChange(const NodeInfo &old_node_info, const NodeInfo &new_node_info) {
58 size_t input_size = old_node_info.seq_node.size();
59 if (input_size != new_node_info.seq_node.size()) {
60 MS_LOG(DEBUG) << "Graph is dynamic, input is tuple, but old seq node info size " << input_size
61 << ", new seq node info size " << new_node_info.seq_node.size();
62 return true;
63 } else {
64 for (size_t i = 0; i < input_size; ++i) {
65 if (IsNodeInfoChange(old_node_info.seq_node[i], new_node_info.seq_node[i])) {
66 return true;
67 }
68 }
69 }
70
71 if (new_node_info.grad_type == InputType::kParameter &&
72 (old_node_info.grad_type == InputType::kParameter || old_node_info.grad_type == InputType::kConstant)) {
73 MS_EXCEPTION_IF_NULL(new_node_info.value);
74 MS_EXCEPTION_IF_NULL(old_node_info.value);
75 auto new_tensor = new_node_info.value->cast<tensor::BaseTensorPtr>();
76 MS_EXCEPTION_IF_NULL(new_tensor);
77 auto old_tensor = old_node_info.value->cast<tensor::BaseTensorPtr>();
78 MS_EXCEPTION_IF_NULL(old_tensor);
79 if (new_tensor->id() != old_tensor->id()) {
80 MS_LOG(DEBUG) << "Graph is dynamic, new node info value: "
81 << (new_node_info.value != nullptr ? new_node_info.value->ToString() : "")
82 << ", grad type: " << new_node_info.grad_type << ", old node info value: "
83 << (old_node_info.value != nullptr ? old_node_info.value->ToString() : "")
84 << ", grad type: " << old_node_info.grad_type;
85 return true;
86 }
87 return false;
88 }
89
90 if (new_node_info.grad_type != old_node_info.grad_type) {
91 MS_LOG(DEBUG) << "Graph is dynamic, new node info grad type: " << new_node_info.grad_type
92 << ", old node info grad type: " << old_node_info.grad_type;
93 return true;
94 }
95
96 if (new_node_info.grad_type == InputType::kOpOutput && new_node_info.op_index != old_node_info.op_index) {
97 MS_LOG(DEBUG) << "Graph is dynamic, new node info op_index: " << new_node_info.op_index
98 << ", old node info op_index: " << old_node_info.op_index;
99 return true;
100 }
101
102 if (new_node_info.grad_type == InputType::kConstant && !IsValuePtrEqual(new_node_info.value, old_node_info.value)) {
103 MS_LOG(DEBUG) << "Graph is dynamic, new node info value: "
104 << (new_node_info.value != nullptr ? new_node_info.value->ToString() : "")
105 << ", grad type: " << new_node_info.grad_type << ", old node info value: "
106 << (old_node_info.value != nullptr ? old_node_info.value->ToString() : "")
107 << ", grad type: " << old_node_info.grad_type;
108 return true;
109 }
110
111 return false;
112 }
113
IsInputsNodeInfoChange(const std::vector<NodeInfo> & old_inputs_node_info,const std::vector<NodeInfo> & new_inputs_node_info)114 bool IsInputsNodeInfoChange(const std::vector<NodeInfo> &old_inputs_node_info,
115 const std::vector<NodeInfo> &new_inputs_node_info) {
116 size_t input_size = old_inputs_node_info.size();
117 if (input_size != new_inputs_node_info.size()) {
118 MS_LOG(DEBUG) << "Graph is dynamic, old_inputs size: " << input_size
119 << "new_inputs size: " << new_inputs_node_info.size();
120 return true;
121 }
122 for (size_t i = 0; i < input_size; ++i) {
123 if (IsNodeInfoChange(old_inputs_node_info[i], new_inputs_node_info[i])) {
124 return true;
125 }
126 }
127 return false;
128 }
129
GetNodeInfoFromValue(const ValuePtr & input)130 NodeInfo GetNodeInfoFromValue(const ValuePtr &input) {
131 if (input->isa<tensor::BaseTensor>()) {
132 NodeInfo node_info;
133 auto tensor = input->cast<tensor::BaseTensorPtr>();
134 auto auto_meta_data = tensor->auto_grad_meta_data();
135 // Scalar tensor
136 if (auto_meta_data == nullptr) {
137 node_info.grad_type = InputType::kConstant;
138 node_info.value = input;
139 return node_info;
140 }
141
142 // Tensor
143 node_info.grad_type = auto_meta_data->input_type();
144 node_info.op_index = auto_meta_data->op_index();
145 if (node_info.grad_type == InputType::kConstant || node_info.grad_type == InputType::kParameter) {
146 node_info.value = input;
147 }
148 return node_info;
149 } else if (input->isa<ValueSequence>()) {
150 NodeInfo node_info;
151 const auto &value_sequence = input->cast<ValueSequencePtr>();
152 for (const auto &i : value_sequence->value()) {
153 node_info.seq_node.emplace_back(GetNodeInfoFromValue(i));
154 }
155 } else if (input->isa<stub::StubNode>()) {
156 auto stub_node = input->cast<stub::StubNodePtr>();
157 MS_EXCEPTION_IF_NULL(stub_node);
158 GetNodeInfoFromValue(stub_node->WaitValue());
159 } else {
160 NodeInfo node_info;
161 node_info.grad_type = InputType::kConstant;
162 node_info.value = input;
163 return node_info;
164 }
165 return NodeInfo{};
166 }
167
168 struct CompareBasedOnAbstract {
IsNodeChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnAbstract169 static bool IsNodeChange(const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &old_node,
170 const DynamicDetectNodeInfoPtr &new_node) {
171 // Compare input abs
172 if (IsDynamicDetectAbsChange(old_node->abs_compare_info.input_abs, new_node->abs_compare_info.input_abs)) {
173 return true;
174 }
175
176 // Compare out abs
177 if (IsDynamicDetectAbsChange(old_node->abs_compare_info.out_abs, new_node->abs_compare_info.out_abs)) {
178 return true;
179 }
180
181 // Get input
182 BuildDynamicDetectInputsNodeInfo(new_node, inputs);
183
184 // Compare input
185 return IsInputsNodeInfoChange(old_node->abs_compare_info.inputs, new_node->abs_compare_info.inputs);
186 }
187
IsDynamicDetectAbsChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnAbstract188 static bool IsDynamicDetectAbsChange(const AbstractBasePtr &old_abs, const AbstractBasePtr &new_abs) {
189 if (old_abs == new_abs) {
190 return false;
191 }
192 if (old_abs == nullptr || new_abs == nullptr) {
193 MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs";
194 return true;
195 }
196 if (!common::IsEqual(old_abs->BuildType(), new_abs->BuildType()) ||
197 !common::IsEqual(old_abs->BuildShape(), new_abs->BuildShape())) {
198 MS_LOG(DEBUG) << "Graph is dynamic, old_abs is different with new_abs, old abs: " << old_abs->ToString()
199 << ", new abs: " << new_abs->ToString();
200 return true;
201 }
202 return false;
203 }
204
IsDynamicDetectAbsChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnAbstract205 static bool IsDynamicDetectAbsChange(const abstract::AbstractBasePtrList &node_abs,
206 const abstract::AbstractBasePtrList &old_node_abs) {
207 if (node_abs.size() != old_node_abs.size()) {
208 MS_LOG(DEBUG) << "Graph is dynamic, node_abs size: " << node_abs.size()
209 << ", old_node_abs size: " << old_node_abs.size();
210 return true;
211 }
212 for (size_t i = 0; i < node_abs.size(); ++i) {
213 if (IsDynamicDetectAbsChange(node_abs[i], old_node_abs[i])) {
214 return true;
215 }
216 }
217 return false;
218 }
219
BuildDynamicDetectInputsNodeInfomindspore::pynative::__anonf68ef89a0111::CompareBasedOnAbstract220 static void BuildDynamicDetectInputsNodeInfo(const DynamicDetectNodeInfoPtr &node, const ValuePtrList &inputs) {
221 std::transform(inputs.begin(), inputs.end(), std::back_inserter(node->abs_compare_info.inputs),
222 [](const auto &item) { return GetNodeInfoFromValue(item); });
223 }
224 };
225
226 struct CompareBasedOnValueSimpleInfo {
IsNodeChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo227 static bool IsNodeChange(const ValuePtrList &inputs, const DynamicDetectNodeInfoPtr &old_node,
228 const DynamicDetectNodeInfoPtr &new_node) {
229 BuildInputsValueSimpleInfo(new_node, inputs);
230 return IsInputsChange(old_node->value_compare_info, new_node->value_compare_info);
231 }
232
BuildInputsValueSimpleInfomindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo233 static void BuildInputsValueSimpleInfo(const DynamicDetectNodeInfoPtr &node, const ValuePtrList &inputs) {
234 size_t input_size = inputs.size();
235 node->value_compare_info.input_value_simple_info.size_ = input_size;
236 node->value_compare_info.input_value_simple_info.shape_vector_.reserve(input_size);
237 node->value_compare_info.input_value_simple_info.dtype_vector_.reserve(input_size);
238 node->value_compare_info.input_value_simple_info.object_type_vector_.reserve(input_size);
239 for (const auto &input : inputs) {
240 node->value_compare_info.inputs.emplace_back(GetNodeInfoFromValue(input));
241
242 (void)node->value_compare_info.input_value_simple_info.shape_vector_.emplace_back(
243 PyNativeAlgo::Common::GetShapeFromValue(input));
244 auto [dtype, obj_type] = PyNativeAlgo::Common::GetTypeFromValue(input);
245 (void)node->value_compare_info.input_value_simple_info.dtype_vector_.emplace_back(dtype);
246 (void)node->value_compare_info.input_value_simple_info.object_type_vector_.emplace_back(obj_type);
247 }
248 }
249
IsInputsChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo250 static bool IsInputsChange(const ValueCompareInfo &old_value_compare_info,
251 const ValueCompareInfo &new_value_compare_info) {
252 if (IsInputsNodeInfoChange(old_value_compare_info.inputs, new_value_compare_info.inputs)) {
253 return true;
254 }
255 return IsValueSimpleInfoChange(old_value_compare_info.input_value_simple_info,
256 new_value_compare_info.input_value_simple_info);
257 }
258
259 template <typename T1, typename T2>
IsNotEuqalmindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo260 static bool IsNotEuqal(const T1 &old_input, const T2 &new_input) {
261 return old_input != new_input;
262 }
263
264 template <typename T1, typename T2>
IsNotEuqalmindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo265 static bool IsNotEuqal(const std::shared_ptr<T1> &old_input, const std::shared_ptr<T2> &new_input) {
266 MS_EXCEPTION_IF_NULL(old_input);
267 MS_EXCEPTION_IF_NULL(new_input);
268 return old_input->type_id() != new_input->type_id();
269 }
270
IsValueSimpleInfoChangemindspore::pynative::__anonf68ef89a0111::CompareBasedOnValueSimpleInfo271 static bool IsValueSimpleInfoChange(const ValueSimpleInfo &old_input_simple_info,
272 const ValueSimpleInfo &new_input_simple_info) {
273 if (old_input_simple_info.size_ != new_input_simple_info.size_) {
274 MS_LOG(DEBUG) << "Graph is dynamic, old_input_simple_info size: " << old_input_simple_info.size_
275 << ", new_input_simple_info size: " << new_input_simple_info.size_;
276 return true;
277 }
278 for (size_t i = 0; i < old_input_simple_info.size_; ++i) {
279 if (IsNotEuqal(old_input_simple_info.shape_vector_[i], new_input_simple_info.shape_vector_[i]) ||
280 IsNotEuqal(old_input_simple_info.dtype_vector_[i], new_input_simple_info.dtype_vector_[i]) ||
281 IsNotEuqal(old_input_simple_info.object_type_vector_[i], new_input_simple_info.object_type_vector_[i])) {
282 MS_LOG(DEBUG) << "Graph is dynamic, old input simple info: " << ValueSimpleInfoToString(old_input_simple_info)
283 << ", new input simple info: " << ValueSimpleInfoToString(new_input_simple_info);
284 return true;
285 }
286 }
287 return false;
288 }
289 };
290
UpdateAbsCache(const std::string & arg_id,const ValuePtr & v,const abstract::BaseShapePtr & base_shape,const abstract::AbstractBasePtr & abs,size_t index)291 void UpdateAbsCache(const std::string &arg_id, const ValuePtr &v, const abstract::BaseShapePtr &base_shape,
292 const abstract::AbstractBasePtr &abs, size_t index) {
293 auto update_abs = abs;
294 if (update_abs == nullptr) {
295 MS_EXCEPTION_IF_NULL(v);
296 auto input_tensor = v->cast<tensor::BaseTensorPtr>();
297 // Just tensor work in unknown shape
298 if (input_tensor == nullptr) {
299 return;
300 }
301 MS_EXCEPTION_IF_NULL(base_shape);
302 update_abs = std::make_shared<abstract::AbstractTensor>(input_tensor->Dtype(), base_shape);
303 }
304 MS_LOG(DEBUG) << "Set arg " << index << ", id " << arg_id << ", to dynamic abs: " << update_abs->ToString();
305 const auto &infer = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->infer_operation();
306 infer->UpdateNodeAbsCacheById(arg_id, update_abs);
307 }
308
GetUnknownShape(const ShapeVector & cur_shape,const ShapeVector & pre_top_cell_shape,ShapeVector * new_shape)309 bool GetUnknownShape(const ShapeVector &cur_shape, const ShapeVector &pre_top_cell_shape, ShapeVector *new_shape) {
310 // Dynamic rank
311 if (cur_shape.size() != pre_top_cell_shape.size()) {
312 MS_LOG(INFO) << "Cur shape size " << cur_shape.size() << " is not equal to top cell arg shape size "
313 << pre_top_cell_shape.size();
314 (void)new_shape->emplace_back(abstract::Shape::kShapeRankAny);
315 return true;
316 }
317 // Dynamic shape
318 for (size_t j = 0; j < cur_shape.size(); ++j) {
319 if (cur_shape[j] == pre_top_cell_shape[j]) {
320 (void)new_shape->emplace_back(cur_shape[j]);
321 } else {
322 (void)new_shape->emplace_back(abstract::Shape::kShapeDimAny);
323 }
324 }
325 // All shape can not be actual, which indicates static shape.
326 if (!IsDynamicShape(*new_shape)) {
327 MS_LOG(DEBUG) << "All shape are actual, is static shape. Cur shape " << cur_shape << ", elem shape "
328 << pre_top_cell_shape << ", and new shape is " << new_shape;
329 return false;
330 }
331 return true;
332 }
333
IsMatch(const ShapeVector & cur_shape,const ShapeVector & pre_top_cell_shape)334 bool IsMatch(const ShapeVector &cur_shape, const ShapeVector &pre_top_cell_shape) {
335 if (cur_shape.size() != pre_top_cell_shape.size() && !pre_top_cell_shape.empty() &&
336 pre_top_cell_shape[kIndex0] != abstract::Shape::kShapeRankAny) {
337 MS_LOG(DEBUG) << "Cur shape size " << cur_shape.size() << " is not equal to pre top cell arg shape size "
338 << pre_top_cell_shape.size();
339 return false;
340 }
341 // Dynamic rank or dynamic shape
342 for (size_t i = 0; i < cur_shape.size(); ++i) {
343 if (cur_shape[i] != pre_top_cell_shape[i] && pre_top_cell_shape[i] != abstract::Shape::kShapeDimAny) {
344 MS_LOG(DEBUG) << "Cur shape " << cur_shape[i] << " can not match pre top cell shape " << pre_top_cell_shape[i];
345 return false;
346 }
347 }
348 return true;
349 }
350 } // namespace
351
// Converts an actual Python input into the object used for compiling with dynamic shape.
//   - tuple / list: converted element by element (recursively) into a new tuple / list;
//   - tensor: replaced by a new tensor with the same data type and shape; when the cached abstract of the
//     original tensor has a dynamic shape, that shape is attached to the new tensor;
//   - anything else: returned unchanged.
py::object DynamicShape::GetDynamicInput(const py::object &actual_input) {
  if (py::isinstance<py::tuple>(actual_input)) {
    const auto src_tuple = py::cast<py::tuple>(actual_input);
    const size_t elem_num = src_tuple.size();
    py::tuple converted(elem_num);
    for (size_t idx = 0; idx < elem_num; ++idx) {
      converted[idx] = GetDynamicInput(src_tuple[idx]);
    }
    return converted;
  }
  if (py::isinstance<py::list>(actual_input)) {
    const auto src_list = py::cast<py::list>(actual_input);
    const size_t elem_num = src_list.size();
    py::list converted;
    for (size_t idx = 0; idx < elem_num; ++idx) {
      converted.append(GetDynamicInput(src_list[idx]));
    }
    return converted;
  }
  if (!py::isinstance<tensor::BaseTensor>(actual_input)) {
    return actual_input;
  }

  const auto &infer = PyNativeAlgo::Common::GetPyNativeExecutor()->forward_executor()->infer_operation();
  auto tensor_ptr = py::cast<tensor::BaseTensorPtr>(actual_input);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  auto dyn_compile_tensor = std::make_shared<tensor::BaseTensor>(tensor_ptr->data_type(), tensor_ptr->shape_c());
  // Look up the abstract cached for this Python object; it may carry a dynamic shape.
  const auto &abs = infer->GetNodeAbsById(PyNativeAlgo::PyParser::GetIdByPyObj(actual_input));
  if (abs != nullptr) {
    auto base_shape = abs->BuildShape();
    MS_EXCEPTION_IF_NULL(base_shape);
    if (base_shape->IsDynamic()) {
      dyn_compile_tensor->set_base_shape(base_shape);
    }
  }
  return PyNativeAlgo::DataConvert::ValueToPyObj(dyn_compile_tensor);
}
386
SaveUnknownShapeAbsFromJit(const ValuePtr & v,const AbstractBasePtr & abs,size_t index)387 void DynamicShape::SaveUnknownShapeAbsFromJit(const ValuePtr &v, const AbstractBasePtr &abs, size_t index) {
388 MS_EXCEPTION_IF_NULL(v);
389 MS_EXCEPTION_IF_NULL(abs);
390 if (v->isa<ValueSequence>() && abs->isa<abstract::AbstractSequence>()) {
391 const auto &v_seq = v->cast<ValueSequencePtr>();
392 const auto &abs_seq = abs->cast<abstract::AbstractSequencePtr>();
393 if (v_seq->size() != abs_seq->size()) {
394 MS_LOG(EXCEPTION) << "Obj tuple size " << v_seq->size() << ", but abstract tuple size " << abs_seq->size();
395 }
396 for (size_t i = 0; i < v_seq->size(); ++i) {
397 SaveUnknownShapeAbsFromJit(v_seq->value()[i], abs_seq->elements()[i], index);
398 }
399 } else if (v->isa<tensor::BaseTensor>() && abs->isa<abstract::AbstractTensor>()) {
400 if (abs->BuildShape()->IsDynamic()) {
401 UpdateAbsCache(PyNativeAlgo::Common::GetIdByValue(v), v, nullptr, abs, ++index);
402 }
403 } else {
404 MS_LOG(EXCEPTION) << "Not match: obj " << v->ToString() << " and abs " << abs->ToString();
405 }
406 }
407
CheckNodeDynamic(const TopCellInfoPtr & top_cell,const ValuePtrList & inputs,const DynamicDetectNodeInfoPtr & node)408 bool NodeDynamicDetect::CheckNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
409 const DynamicDetectNodeInfoPtr &node) {
410 std::unique_lock<std::mutex> lock(async_mutex_);
411 MS_EXCEPTION_IF_NULL(top_cell);
412 if (top_cell->use_dynamic_shape_process()) {
413 top_cell->IncreaseOpIndex();
414 return true;
415 }
416
417 const size_t node_idx = top_cell->op_index();
418 bool node_is_dynamic = false;
419 bool use_dynamic_shape_process =
420 top_cell->has_bprop_cut_op() || (node_is_dynamic = IsNodeDynamic(top_cell, inputs, node, node_idx)) == true;
421 top_cell->IncreaseOpIndex();
422 if (use_dynamic_shape_process) {
423 MS_LOG(INFO) << "Set use_dynamic_shape_process: " << use_dynamic_shape_process;
424 top_cell->set_use_dynamic_shape_process(use_dynamic_shape_process);
425 py::gil_scoped_acquire gil_acquire;
426 (void)cell_id_with_dynamic_detect_nodes_.erase(top_cell->obj_id_with_grad_order());
427 }
428 if (node_is_dynamic) {
429 auto context = MsContext::GetInstance();
430 MS_EXCEPTION_IF_NULL(context);
431 if (context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
432 MS_LOG(WARNING) << "Detect dynamic shape or dynamic graph structure, the python stack is: ";
433 py::gil_scoped_acquire acquire_gil;
434 py::exec(R"(
435 import traceback
436 traceback.print_stack()
437 )");
438 }
439 }
440 return use_dynamic_shape_process;
441 }
442
IsNodeDynamic(const TopCellInfoPtr & top_cell,const ValuePtrList & inputs,const DynamicDetectNodeInfoPtr & node,size_t node_idx)443 bool NodeDynamicDetect::IsNodeDynamic(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
444 const DynamicDetectNodeInfoPtr &node, size_t node_idx) {
445 MS_EXCEPTION_IF_NULL(node);
446 if (top_cell->is_need_save_dynamic_detect_nodes()) {
447 SaveDynamicDetectNodeInfoInFirstTime(top_cell, inputs, node, node_idx);
448 // The net is regarded as a static net by default in the first time.
449 return false;
450 }
451
452 MS_LOG(DEBUG) << "Check node " << (node->op_prim != nullptr ? node->op_prim->name() : "") << " node_idx: " << node_idx
453 << ", is_jit_node: " << node->is_graph_node << ", graph_phase: " << node->graph_phase
454 << ", obj_id_with_grad_order: " << top_cell->obj_id_with_grad_order()
455 << ", cell id: " << top_cell->cell_id();
456 const auto &dynamic_nodes =
457 cell_id_with_dynamic_detect_nodes_[top_cell->obj_id_with_grad_order()][top_cell->cell_id()];
458 if (node_idx >= dynamic_nodes.size()) {
459 MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << ", cur node_idx is: " << node_idx
460 << ", graph is dynamic.";
461 return true;
462 }
463
464 // 1.Detect jit phase
465 const DynamicDetectNodeInfoPtr &old_node_info = dynamic_nodes[node_idx];
466 if (node->is_graph_node) {
467 if (!old_node_info->is_graph_node || node->graph_phase != old_node_info->graph_phase) {
468 MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
469 << ", new is_graph_node: " << node->is_graph_node << ", old graph_phase "
470 << old_node_info->is_graph_node << ", new graph_phase: " << node->graph_phase;
471 return true;
472 }
473 return false;
474 }
475
476 // 2.Detect prim
477 if (IsDynamicDetectPrimChange(old_node_info->op_prim, node->op_prim)) {
478 MS_LOG(DEBUG) << "Graph is dynamic, old node prim: "
479 << (old_node_info->op_prim != nullptr
480 ? old_node_info->op_prim->name() + ", attr: " + old_node_info->op_prim->GetAttrsText()
481 : "")
482 << " new node prim: "
483 << (node->op_prim != nullptr ? node->op_prim->name() + ", attr: " + node->op_prim->GetAttrsText()
484 : "")
485 << " node_idx: " << node_idx;
486 return true;
487 }
488
489 // 3.Detect inputs
490 if (node->is_value_compare) {
491 return CompareBasedOnValueSimpleInfo::IsNodeChange(inputs, old_node_info, node);
492 } else {
493 return CompareBasedOnAbstract::IsNodeChange(inputs, old_node_info, node);
494 }
495 }
496
SaveDynamicDetectNodeInfoInFirstTime(const TopCellInfoPtr & top_cell,const ValuePtrList & inputs,const DynamicDetectNodeInfoPtr & node,size_t node_idx)497 void NodeDynamicDetect::SaveDynamicDetectNodeInfoInFirstTime(const TopCellInfoPtr &top_cell, const ValuePtrList &inputs,
498 const DynamicDetectNodeInfoPtr &node, size_t node_idx) {
499 MS_EXCEPTION_IF_NULL(node);
500 if (node->is_value_compare) {
501 CompareBasedOnValueSimpleInfo::BuildInputsValueSimpleInfo(node, inputs);
502 } else {
503 CompareBasedOnAbstract::BuildDynamicDetectInputsNodeInfo(node, inputs);
504 }
505 (void)cell_id_with_dynamic_detect_nodes_[top_cell->obj_id_with_grad_order()][top_cell->cell_id()].emplace_back(node);
506 MS_LOG(DEBUG) << "Save node " << (node->op_prim != nullptr ? node->op_prim->name() : "")
507 << " firstly, node_idx: " << node_idx << ", is_jit_node: " << node->is_graph_node
508 << ", graph_phase: " << node->graph_phase
509 << ", obj_id_with_grad_order: " << top_cell->obj_id_with_grad_order()
510 << ", cell id: " << top_cell->cell_id();
511 }
512
IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr & top_cell,bool use_dynamic_shape_process)513 bool NodeDynamicDetect::IsNeedSaveDynamicDetectNodes(const TopCellInfoPtr &top_cell, bool use_dynamic_shape_process) {
514 if (use_dynamic_shape_process) {
515 // top cell is already dynamic shape, no need save nodes.
516 return false;
517 }
518 MS_EXCEPTION_IF_NULL(top_cell);
519 auto cell_iter = cell_id_with_dynamic_detect_nodes_.find(top_cell->obj_id_with_grad_order());
520 if (cell_iter == cell_id_with_dynamic_detect_nodes_.end()) {
521 // Cell is not found in cell_id_with_dynamic_detect_nodes_, need save nodes first.
522 return true;
523 }
524
525 const auto &cell_infos = cell_iter->second;
526 if (cell_infos.size() == 1) {
527 // top_cell->cell_id() is cell id with inputs shape, if cell id in cell_id_with_dynamic_detect_nodes_
528 // id same with top_cell->cell_id(), no need save nodes.
529 return cell_infos.begin()->first != top_cell->cell_id();
530 } else if (cell_infos.size() == kMaxCacheDynamicShapeCellNum) {
531 auto cell_infos_iter = cell_infos.find(top_cell->cell_id());
532 if (cell_infos_iter == cell_infos.end()) {
533 // cell_id_with_dynamic_detect_nodes_ has two cell id already, current cell is is different
534 // with them. So set_use_dynamic_shape_process for top cell.
535 top_cell->set_use_dynamic_shape_process(true);
536 (void)cell_id_with_dynamic_detect_nodes_.erase(top_cell->obj_id_with_grad_order());
537 MS_LOG(INFO) << "Set use_dynamic_shape_process: " << use_dynamic_shape_process << ", already cached "
538 << cell_infos.size() << " top cell, cur top cell shape is different: " << top_cell->cell_id();
539 }
540 } else {
541 MS_LOG(EXCEPTION) << "cell_info.size(): " << cell_infos.size() << " is invalid";
542 }
543 return false;
544 }
545
SetDynamicInput(const py::object & obj,const py::args & args)546 void TopCellUnknownShapeDetect::SetDynamicInput(const py::object &obj, const py::args &args) {
547 const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
548 // After first step, set inputs no need work again. Because the top cell of first step is already unknown shape and
549 // follow step will keep unknown shape always, special input_signature
550 if (obj_with_by_inputs_.find(obj_id) != obj_with_by_inputs_.end()) {
551 MS_LOG(DEBUG) << "Obj " << obj_id << " has done set inputs before";
552 return;
553 }
554 auto &arg_base_shape_vec = obj_id_args_info_by_set_inputs_[obj_id];
555 size_t args_size = args.size();
556 arg_base_shape_vec.reserve(args_size);
557 for (size_t i = 0; i < args_size; ++i) {
558 (void)arg_base_shape_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i])->ToAbstract()->BuildShape());
559 }
560 TryChangeTopCellToUnknownShape(obj_id, arg_base_shape_vec, false);
561 (void)obj_with_by_inputs_.emplace(obj_id);
562 }
563
TryChangeTopCellToUnknownShape(const std::string & obj_id,const abstract::BaseShapePtrList & arg_base_shape_vec,bool is_auto_detect)564 void TopCellUnknownShapeDetect::TryChangeTopCellToUnknownShape(const std::string &obj_id,
565 const abstract::BaseShapePtrList &arg_base_shape_vec,
566 bool is_auto_detect) {
567 const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
568 if (is_auto_detect) {
569 // From auto detect
570 auto &top_cell_list = grad_executor->already_run_top_cell();
571 const auto it = std::find_if(top_cell_list.begin(), top_cell_list.end(), [&obj_id](const auto &elem) {
572 return elem.second->input_args_info() != nullptr && elem.second->input_args_info()->obj_id == obj_id;
573 });
574 if (it != top_cell_list.end()) {
575 // Pre top cell is already unknown shape, check current top cell can match it
576 if (it->second->is_unknown_shape() && CanFindMatchedUnknownShapeTopCell(it->second, arg_base_shape_vec)) {
577 MS_LOG(DEBUG) << "Pre top cell has already been unknown shape and can match current top cell";
578 ChangeTopCellToUnknownShape(grad_executor->top_cell(), it->second->input_args_info()->input_arg_base_shape_vec);
579 return;
580 }
581 // If not match before, compare shape and change current top cell do unknown shape
582 if (SetTopCellUnknownShape(grad_executor->top_cell(), it->second, arg_base_shape_vec)) {
583 (void)top_cell_list.erase(it);
584 return;
585 }
586 } else {
587 // Set inputs, first step top cell working here
588 const auto item = obj_id_args_info_by_set_inputs_.find(grad_executor->top_cell()->input_args_info()->obj_id);
589 if (item != obj_id_args_info_by_set_inputs_.end()) {
590 const auto &input_args_info = grad_executor->top_cell()->input_args_info();
591 UpdateUnknownShapeAbsCache(input_args_info->input_arg_id_vec, input_args_info->input_arg_value_vec,
592 item->second);
593 (void)obj_id_args_info_by_set_inputs_.erase(item);
594 return;
595 }
596 // C1.set_inputs, run C1(x); C2 is top cell, and run C2(x).
597 if (std::any_of(arg_base_shape_vec.begin(), arg_base_shape_vec.end(),
598 [](const abstract::BaseShapePtr &base_shape) { return base_shape->IsDynamic(); })) {
599 MS_LOG(DEBUG) << "Top cell is unknown shape now";
600 grad_executor->top_cell()->set_is_unknown_shape(true);
601 }
602 }
603 } else {
604 // From set inputs. Has not create top cell yet
605 if (grad_executor->TopCellHasNotBeenCreate()) {
606 return;
607 }
608 // Jit, top cell create first, then set inputs run
609 const auto item = obj_id_args_info_by_set_inputs_.find(grad_executor->top_cell()->input_args_info()->obj_id);
610 if (item != obj_id_args_info_by_set_inputs_.end()) {
611 MS_LOG(DEBUG) << "Get jit set inputs";
612 ChangeTopCellToUnknownShape(grad_executor->top_cell(), arg_base_shape_vec);
613 (void)obj_id_args_info_by_set_inputs_.erase(item);
614 }
615 }
616 }
617
UpdateUnknownShapeAbsCache(const std::vector<string> & input_arg_id_vec,const std::vector<ValuePtr> & input_arg_value_vec,const std::vector<abstract::BaseShapePtr> & args_base_shape)618 void TopCellUnknownShapeDetect::UpdateUnknownShapeAbsCache(const std::vector<string> &input_arg_id_vec,
619 const std::vector<ValuePtr> &input_arg_value_vec,
620 const std::vector<abstract::BaseShapePtr> &args_base_shape) {
621 for (size_t i = 0; i < args_base_shape.size(); i++) {
622 MS_EXCEPTION_IF_NULL(args_base_shape[i]);
623 MS_EXCEPTION_IF_NULL(input_arg_value_vec[i]);
624 if (args_base_shape[i]->IsDynamic()) {
625 if (args_base_shape[i]->isa<abstract::Shape>()) {
626 UpdateAbsCache(input_arg_id_vec[i], input_arg_value_vec[i], args_base_shape[i], nullptr, i);
627 } else if (args_base_shape[i]->isa<abstract::SequenceShape>()) {
628 // Input arg is list or tuple
629 const auto &seq_shape = args_base_shape[i]->cast<abstract::SequenceShapePtr>();
630 const auto &seq_v = input_arg_value_vec[i]->cast<ValueSequencePtr>();
631 MS_EXCEPTION_IF_NULL(seq_v);
632 if (seq_v->size() != seq_shape->size()) {
633 MS_LOG(EXCEPTION) << "Sequence value size " << seq_v->size() << " is not equal to seq shape size "
634 << seq_shape->size();
635 }
636 std::vector<std::string> id_vec;
637 PyNativeAlgo::Common::SplitString(input_arg_id_vec[i], &id_vec);
638 if (id_vec.size() != seq_shape->size()) {
639 MS_LOG(EXCEPTION) << "Id size " << id_vec.size() << " is not equal to seq shape size " << seq_shape->size();
640 }
641 for (size_t j = 0; j < seq_shape->size(); ++j) {
642 UpdateAbsCache(id_vec[j], seq_v->value()[j], seq_shape->shape()[j], nullptr, i + j);
643 }
644 }
645 }
646 }
647 }
648
UpdateArgsAbsToUnknownShapeAbs(const py::object & obj,const py::args & args)649 void TopCellUnknownShapeDetect::UpdateArgsAbsToUnknownShapeAbs(const py::object &obj, const py::args &args) {
650 if (obj_id_args_info_by_set_inputs_.empty()) {
651 return;
652 }
653
654 const auto &grad_executor = PyNativeAlgo::Common::GetPyNativeExecutor()->grad_executor();
655 bool top_cell_has_not_been_create = grad_executor->TopCellHasNotBeenCreate();
656 // Top cell is already unknown shape
657 if (!top_cell_has_not_been_create && grad_executor->top_cell()->is_unknown_shape()) {
658 return;
659 }
660
661 // Current cell is has no set_inputs
662 const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
663 const auto it = obj_id_args_info_by_set_inputs_.find(obj_id);
664 if (it == obj_id_args_info_by_set_inputs_.end()) {
665 return;
666 }
667
668 // Common cell args id and value not create in ParsePyArgsToInputArgsInfo, need get them now.
669 // Update current cell id cache which maybe used for top cell
670 const auto &args_id_v = PyNativeAlgo::PyParser::GetArgsIdAndValue(args);
671 UpdateUnknownShapeAbsCache(args_id_v.first, args_id_v.second, it->second);
672
673 // C1.set_inputs, run C1(x); C2 is top cell, and run C2(x).
674 if (top_cell_has_not_been_create) {
675 // Has not create top cell yet
676 (void)obj_id_args_info_by_set_inputs_.erase(it);
677 return;
678 }
679
680 // C1 is top cell, run C1(x); C2 set_inputs, and run C2(x).
681 UpdatePossibleTopCellToUnknownShape(grad_executor->top_cell(), args_id_v.first, it->second);
682 (void)obj_id_args_info_by_set_inputs_.erase(it);
683 }
684
UpdatePossibleTopCellToUnknownShape(const TopCellInfoPtr & cur_top_cell,const std::vector<string> & cur_arg_id_vec,const abstract::BaseShapePtrList & cur_args_shape)685 void TopCellUnknownShapeDetect::UpdatePossibleTopCellToUnknownShape(const TopCellInfoPtr &cur_top_cell,
686 const std::vector<string> &cur_arg_id_vec,
687 const abstract::BaseShapePtrList &cur_args_shape) {
688 MS_LOG(DEBUG) << "Update possible top cell";
689 auto cur_top_cell_base_shape_vec = cur_top_cell->input_args_info()->input_arg_base_shape_vec;
690 const auto &cur_top_cell_id_vec = cur_top_cell->input_args_info()->input_arg_id_vec;
691 bool need_change_top_cell_info = false;
692 // Check top cell args id is the same with current set inputs cell. If dynamic shape, update top cell to unknown shape
693 for (size_t i = 0; i < cur_arg_id_vec.size(); ++i) {
694 auto it = std::find(cur_top_cell_id_vec.begin(), cur_top_cell_id_vec.end(), cur_arg_id_vec[i]);
695 if (it != cur_top_cell_id_vec.end() && cur_args_shape[i]->IsDynamic()) {
696 auto id_index = it - cur_top_cell_id_vec.begin();
697 cur_top_cell_base_shape_vec[id_index] = cur_args_shape[i];
698 need_change_top_cell_info = true;
699 }
700 }
701 // Change current top cell info
702 if (need_change_top_cell_info) {
703 cur_top_cell->ChangeTopCellInfo(cur_top_cell_base_shape_vec);
704 }
705 }
706
CanFindMatchedUnknownShapeTopCell(const TopCellInfoPtr & pre_top_cell,const abstract::BaseShapePtrList & cur_args_shape)707 bool TopCellUnknownShapeDetect::CanFindMatchedUnknownShapeTopCell(const TopCellInfoPtr &pre_top_cell,
708 const abstract::BaseShapePtrList &cur_args_shape) {
709 for (size_t i = 0; i < cur_args_shape.size(); ++i) {
710 const auto &cur_shape = cur_args_shape[i];
711 const auto &pre_top_cell_shape = pre_top_cell->input_args_info()->input_arg_base_shape_vec[i];
712 MS_EXCEPTION_IF_NULL(cur_shape);
713 MS_EXCEPTION_IF_NULL(pre_top_cell_shape);
714 if (cur_shape->isa<abstract::Shape>() && pre_top_cell_shape->isa<abstract::Shape>()) {
715 if (!IsMatch(cur_shape->cast<abstract::ShapePtr>()->shape(),
716 pre_top_cell_shape->cast<abstract::ShapePtr>()->shape())) {
717 return false;
718 }
719 } else if (cur_shape->isa<abstract::SequenceShape>() && pre_top_cell_shape->isa<abstract::SequenceShape>()) {
720 // Input arg is list or tuple
721 const auto &cur_shape_seq = cur_shape->cast<abstract::SequenceShapePtr>();
722 const auto &top_cell_shape_seq = pre_top_cell_shape->cast<abstract::SequenceShapePtr>();
723 size_t cur_shape_size = cur_shape_seq->size();
724 if (cur_shape_size != top_cell_shape_seq->size()) {
725 MS_LOG(DEBUG) << "The " << i << "th args shape size is not the same, cur is " << cur_shape_seq->size()
726 << " and the elem is " << top_cell_shape_seq->size();
727 return false;
728 }
729 for (size_t j = 0; j < cur_shape_size; ++j) {
730 MS_EXCEPTION_IF_NULL(cur_shape_seq->shape()[j]);
731 MS_EXCEPTION_IF_NULL(top_cell_shape_seq->shape()[j]);
732 if (!IsMatch(cur_shape_seq->shape()[j]->cast<abstract::ShapePtr>()->shape(),
733 top_cell_shape_seq->shape()[j]->cast<abstract::ShapePtr>()->shape())) {
734 return false;
735 }
736 }
737 }
738 }
739 return true;
740 }
741
ChangeTopCellToUnknownShape(const TopCellInfoPtr & top_cell,const abstract::BaseShapePtrList & args_unknown_shape)742 void TopCellUnknownShapeDetect::ChangeTopCellToUnknownShape(const TopCellInfoPtr &top_cell,
743 const abstract::BaseShapePtrList &args_unknown_shape) {
744 if (top_cell->input_args_info()->input_arg_base_shape_vec.size() != args_unknown_shape.size()) {
745 MS_LOG(EXCEPTION) << "Top cell args base shape size "
746 << top_cell->input_args_info()->input_arg_base_shape_vec.size()
747 << " is not equal to update unknown shape size " << args_unknown_shape.size();
748 }
749 UpdateUnknownShapeAbsCache(top_cell->input_args_info()->input_arg_id_vec,
750 top_cell->input_args_info()->input_arg_value_vec, args_unknown_shape);
751 top_cell->ChangeTopCellInfo(args_unknown_shape);
752 }
753
SetTopCellUnknownShape(const TopCellInfoPtr & cur_top_cell,const TopCellInfoPtr & pre_top_cell,const abstract::BaseShapePtrList & args_shape)754 bool TopCellUnknownShapeDetect::SetTopCellUnknownShape(const TopCellInfoPtr &cur_top_cell,
755 const TopCellInfoPtr &pre_top_cell,
756 const abstract::BaseShapePtrList &args_shape) {
757 abstract::BaseShapePtrList args_unknown_shape;
758 args_unknown_shape.reserve(args_shape.size());
759 for (size_t i = 0; i < args_shape.size(); ++i) {
760 const auto &cur_shape = args_shape[i];
761 const auto &pre_top_cell_shape = pre_top_cell->input_args_info()->input_arg_base_shape_vec[i];
762 MS_EXCEPTION_IF_NULL(cur_shape);
763 MS_EXCEPTION_IF_NULL(pre_top_cell_shape);
764 if (cur_shape->isa<abstract::Shape>() && pre_top_cell_shape->isa<abstract::Shape>()) {
765 ShapeVector new_shape;
766 auto has_unknown = GetUnknownShape(cur_shape->cast<abstract::ShapePtr>()->shape(),
767 pre_top_cell_shape->cast<abstract::ShapePtr>()->shape(), &new_shape);
768 if (has_unknown) {
769 (void)args_unknown_shape.emplace_back(std::make_shared<abstract::Shape>(new_shape));
770 }
771 } else if (cur_shape->isa<abstract::SequenceShape>() && pre_top_cell_shape->isa<abstract::SequenceShape>()) {
772 // Input arg is list or tuple
773 const auto &cur_shape_seq = cur_shape->cast<abstract::SequenceShapePtr>();
774 MS_EXCEPTION_IF_NULL(cur_shape_seq);
775 const auto &pre_top_cell_shape_seq = pre_top_cell_shape->cast<abstract::SequenceShapePtr>();
776 size_t cur_shape_size = cur_shape_seq->size();
777 if (cur_shape_size != pre_top_cell_shape_seq->size()) {
778 MS_LOG(DEBUG) << "The " << i << "th args shape size is not the same, cur is " << cur_shape_seq->size()
779 << " and the elem is " << pre_top_cell_shape_seq->size();
780 }
781 abstract::BaseShapePtrList shape_ptr_list;
782 for (size_t j = 0; j < cur_shape_size; ++j) {
783 const auto &cur_shape_elem = cur_shape_seq->shape()[j]->cast<abstract::ShapePtr>();
784 const auto &pre_top_cell_shape_elem = pre_top_cell_shape_seq->shape()[j]->cast<abstract::ShapePtr>();
785 MS_EXCEPTION_IF_NULL(pre_top_cell_shape_elem);
786 ShapeVector new_shape;
787 auto has_unknown = GetUnknownShape(cur_shape_elem->shape(), pre_top_cell_shape_elem->shape(), &new_shape);
788 if (has_unknown) {
789 (void)shape_ptr_list.emplace_back(std::make_shared<abstract::Shape>(new_shape));
790 }
791 }
792 if (shape_ptr_list.size() == cur_shape_size) {
793 (void)args_unknown_shape.emplace_back(std::make_shared<abstract::TupleShape>(shape_ptr_list));
794 }
795 } else {
796 MS_LOG(DEBUG) << "The " << i << "th args shape type is not the same, cur is " << cur_shape->ToString()
797 << " and the elem is " << pre_top_cell_shape->ToString();
798 return false;
799 }
800 }
801 if (args_unknown_shape.size() == args_shape.size()) {
802 ChangeTopCellToUnknownShape(cur_top_cell, args_unknown_shape);
803 return true;
804 }
805 return false;
806 }
807 } // namespace pynative
808 } // namespace mindspore
809