# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing auto dynamic shape interface methods."""

import os
from mindspore import log as logger
from mindspore._c_expression import GraphExecutor_, Tensor
from mindspore.common._utils import is_shape_unknown, is_dim_unknown
from mindspore.common.parameter import Parameter

# Placeholder for a single unknown dimension in a generalized shape
# (used when the ranks match but the dims differ).
SHAPE_DIM_ANY = -1
# Placeholder for an unknown rank in a generalized shape
# (used when the ranks differ).
SHAPE_RANK_ANY = -2

# Registry of _AutoIdentifyDynamicShape instances, keyed by the caller-supplied
# key_id (see the module-level helpers at the bottom). The misspelled name is
# kept for compatibility with existing references.
auto_dynamic_shepe_dict = {}


class _AutoDynamicShapeManager:
    """
    Represents a function to manage auto identify dynamic shape.

    Holds two caches of compile args, each paired with a ``phase -> compile args`` dict:

    * the "real" cache stores the first two sets of compile args that were compiled;
    * the "generalize" cache stores later compile args, bounded to
      ``save_cache_number - 2`` entries and evicted oldest-first.
    """

    def __init__(self):
        self.real_shape_cache = []
        self.generalize_shape_cache = []
        self.real_phase_and_compile_args_dict = {}
        self.generalize_phase_and_compile_args_dict = {}
        self._graph_executor = GraphExecutor_.get_instance()

    def __del__(self):
        self.real_shape_cache = []
        self.generalize_shape_cache = []
        self.real_phase_and_compile_args_dict = {}
        self.generalize_phase_and_compile_args_dict = {}

    @staticmethod
    def is_tensor_equal(input_elem, cache_elem):
        """Return True if two tensors/parameters have the same shape and dtype."""
        return bool(input_elem.shape == cache_elem.shape and input_elem.dtype == cache_elem.dtype)

    @staticmethod
    def _get_input_generalize_number(arg_list, is_shape_input):
        """Count the inputs whose shape is unknown.

        Args:
            arg_list: Either a list of shapes (``is_shape_input=True``) or a list of arguments.
            is_shape_input (bool): Whether ``arg_list`` holds shapes rather than arguments.

        Returns:
            int, the number of unknown shapes (for argument lists, only Tensor arguments count).
        """
        if is_shape_input:
            return sum(1 for arg in arg_list if is_shape_unknown(arg))
        return sum(1 for arg in arg_list if isinstance(arg, Tensor) and is_shape_unknown(arg.shape))

    def get_real_shape_cache_number(self):
        """Return the number of entries in the real shape cache."""
        return len(self.real_shape_cache)

    def get_real_shape_cache(self):
        """Return the real shape cache list."""
        return self.real_shape_cache

    def get_generalize_shape_cache_number(self):
        """Return the number of entries in the generalize shape cache."""
        return len(self.generalize_shape_cache)

    def get_generalize_shape_cache(self):
        """Return the generalize shape cache list."""
        return self.generalize_shape_cache

    def get_cache_by_type(self, cache_type):
        """Return the real cache for ``cache_type == "real"``, otherwise the generalize cache."""
        if cache_type == "real":
            return self.real_shape_cache
        return self.generalize_shape_cache

    def get_compile_args_shape_without_sink(self, input_args, res_shape):
        """Append the shape of each argument in ``input_args`` to ``res_shape`` (non-sink mode).

        Tensors contribute their shape, int/float scalars contribute ``[]`` and
        lists/tuples contribute a nested list built recursively. Other
        arguments (e.g. ``None``) are skipped. The result is only used for logging.
        """
        for arg in input_args:
            if isinstance(arg, Tensor):
                res_shape.append(arg.shape)
            elif isinstance(arg, (int, float)):
                res_shape.append([])
            elif isinstance(arg, (tuple, list)):
                tmp_shape = []
                self.get_compile_args_shape_without_sink(arg, tmp_shape)
                res_shape.append(tmp_shape)

    def get_compile_args_shape(self, input_args, is_sink_mode):
        """Return the shapes of ``input_args``; in sink mode the args already are shapes."""
        if is_sink_mode:
            return input_args

        res_shape = []
        self.get_compile_args_shape_without_sink(input_args, res_shape)
        return res_shape

    def find_compile_args_in_shape_cache(self, input_args, cache_type):
        """Return True if ``input_args`` matches any entry of the real/generalize cache."""
        shape_cache = self.get_cache_by_type(cache_type)
        return any(self._compare_input_args_and_cache_args(input_args, cache_args) for cache_args in shape_cache)

    def update_phase_and_compile_args(self, compile_args, phase, save_cache_number, is_sink_mode, aux=None):
        """Record ``phase`` and its ``compile_args`` in the real or generalize cache.

        The first two phases go into the real cache. Later phases go into the
        generalize cache, which holds at most ``save_cache_number - 2`` entries.
        The oldest entries are evicted (and their compiled resources released)
        to make room.

        Args:
            compile_args: The compile args used for ``phase``.
            phase: The compile phase key.
            save_cache_number (int): Total cache budget (real + generalize).
            is_sink_mode (bool): Whether the network runs in sink mode.
            aux: In sink mode, the object whose ``__network_manage__`` dict holds per-phase networks.
        """
        if phase in self.real_phase_and_compile_args_dict:
            logger.debug('phase=%r is in real phase and compile args dict.', phase)
            return

        if phase in self.generalize_phase_and_compile_args_dict:
            logger.debug('phase=%r is in generalize phase and compile args dict.', phase)
            return

        if len(self.real_shape_cache) < 2:
            logger.debug(f'The real shape cache number is {len(self.real_shape_cache)}, is less than 2,'
                         f'phase=%r should be saved in real shape cache.', phase)
            self.real_phase_and_compile_args_dict[phase] = compile_args
            self.real_shape_cache.append(compile_args)
            return

        max_save_cache_number = save_cache_number - 2

        # Evict oldest entries until there is room. The emptiness check prevents
        # an IndexError when max_save_cache_number <= 0 and the cache is empty.
        while self.generalize_phase_and_compile_args_dict and \
                len(self.generalize_phase_and_compile_args_dict) >= max_save_cache_number:
            self._delete_oldest_generalize_phase(is_sink_mode, aux)

        # save phase and compile args into cache
        logger.info(f'The generalize shape cache number is {len(self.generalize_shape_cache)}, is less than '
                    f'{max_save_cache_number}, phase=%r should be saved in generalize shape cache.', phase)
        self.generalize_phase_and_compile_args_dict[phase] = compile_args
        self.generalize_shape_cache.append(compile_args)

    def _delete_oldest_generalize_phase(self, is_sink_mode, aux):
        """Remove the oldest generalize-cache phase, its compile args and its compiled resources."""
        # dicts keep insertion order, so the first key is the oldest phase.
        delete_phase = next(iter(self.generalize_phase_and_compile_args_dict))
        delete_compile_args = self.generalize_phase_and_compile_args_dict.pop(delete_phase)

        if is_sink_mode:
            if hasattr(aux, '__network_manage__') and delete_phase in aux.__network_manage__:
                del aux.__network_manage__[delete_phase]
        else:
            self._graph_executor.del_net_res(None, {delete_phase})

        self.generalize_shape_cache.remove(delete_compile_args)

    def _compare_input_args_and_cache_args(self, input_args, cache_args):
        """Return True if ``input_args`` matches ``cache_args`` element by element.

        Both sequences must have the same length. Tensors match on shape and
        dtype, int/float scalars on value, lists/tuples recursively, and
        ``None`` only matches ``None``. Any other pairing is a mismatch.
        """
        if len(input_args) != len(cache_args):
            return False
        for arg, cache in zip(input_args, cache_args):
            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                if not self.is_tensor_equal(arg, cache):
                    return False
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                if not self._compare_input_args_and_cache_args(arg, cache):
                    return False
            elif isinstance(arg, (int, float)) and isinstance(cache, (int, float)):
                if arg != cache:
                    return False
            elif not (arg is None and cache is None):
                return False
        return True


class _AutoIdentifyDynamicShape:
    """
    Represents a function auto identify dynamic shape.

    Chooses the compile args for a call: either the concrete args (when they
    match or fill the real cache) or args whose tensor shapes are generalized
    against the cached entries. Identification only runs when the environment
    variable ``MS_AUTO_DYNAMIC_SHAPE_ENABLE`` is set to a non-empty value other than ``"0"``.
    """

    def __init__(self):
        self.all_shape_cache = {}
        self.is_sink_mode = False
        self.is_enable_auto_dynamic_shape = True
        self.save_cache_number = 3
        self.enable_auto_identify = os.getenv('MS_AUTO_DYNAMIC_SHAPE_ENABLE')
        self.auto_dynamic_shape_manager = _AutoDynamicShapeManager()

    def __del__(self):
        self.all_shape_cache = {}
        self.is_sink_mode = False
        self.is_enable_auto_dynamic_shape = True
        self.save_cache_number = 3

    def _check_input_args_number(self, args_list):
        """Return False if ``args_list`` length differs from the first real cache entry."""
        if self.auto_dynamic_shape_manager.get_real_shape_cache_number() > 0:
            first_real_cache = self.auto_dynamic_shape_manager.get_real_shape_cache()[0]
            if len(first_real_cache) != len(args_list):
                return False
        return True

    def _check_input_tensor_type(self, args_list, cache_list):
        """Return False if argument types, tensor dtypes or scalar values differ from the cache entry."""
        for (arg, cache) in zip(args_list, cache_list):
            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                if arg.dtype != cache.dtype:
                    logger.debug((f'input tensor type = {arg.dtype}, cache tensor type = {cache.dtype}, '
                                  f'tensor types are not same.'))
                    return False
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                if not self._check_input_tensor_type(arg, cache):
                    return False
            elif (isinstance(arg, int) and isinstance(cache, int)) or \
                    (isinstance(arg, float) and isinstance(cache, float)):
                if arg != cache:
                    return False
            elif isinstance(arg, Tensor) and not isinstance(cache, Tensor):
                return False
            elif isinstance(arg, (int, float)) and not isinstance(cache, (int, float)):
                return False
            elif isinstance(arg, (tuple, list)) and not isinstance(cache, (tuple, list)):
                return False
        return True

    def _check_input_number_and_type(self, args_list):
        """Check ``args_list`` against the first real cache entry (argument count and types)."""
        if not self._check_input_args_number(args_list):
            return False

        if self.auto_dynamic_shape_manager.get_real_shape_cache_number() > 0:
            cache_list = self.auto_dynamic_shape_manager.get_real_shape_cache()[0]
            if not self._check_input_tensor_type(args_list, cache_list):
                return False
        return True

    def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode):
        """Check whether auto dynamic shape identification can be applied to ``args_list``.

        Returns False when:

        * the feature is disabled by the environment;
        * a non-sink call has an empty argument list (including an empty nested list);
        * an argument has an unsupported type;
        * an argument is a tensor whose shape is unknown or empty.

        In non-sink mode, nested lists/tuples are checked recursively.
        """
        if not is_sink_mode and not args_list:
            return False

        if not self.enable_auto_identify:
            self.enable_auto_identify = "0"

        if self.enable_auto_identify == "0":
            return False

        for elem in args_list:
            if elem is None:
                continue
            if not isinstance(elem, (list, tuple, Tensor, int, float)):
                return False
            if isinstance(elem, Tensor) and (is_shape_unknown(elem.shape) or (not elem.shape)):
                return False
            # Only stop on failure: the remaining top-level arguments must still be validated.
            if not is_sink_mode and isinstance(elem, (list, tuple)) and \
                    not self._is_enable_auto_dynamic_shape(elem, is_sink_mode):
                return False
        return True

    @staticmethod
    def _do_generalize_in_sink(input_arg, cache, input_index, cache_index, cache_type):
        """Generalize one shape in sink mode.

        Returns ``cache`` when the shapes are equal. Otherwise returns
        ``[SHAPE_RANK_ANY]`` if the ranks differ, or a list of ``SHAPE_DIM_ANY``
        of the same rank.

        Raises:
            ValueError: If ``input_arg`` is empty (a scalar input).
        """
        if not input_arg:
            raise ValueError("In sink mode, cell input can not be scalar.")

        if input_arg == cache:
            return cache

        if len(input_arg) != len(cache):
            shape_value = [SHAPE_RANK_ANY]
        else:
            shape_value = [SHAPE_DIM_ANY] * len(input_arg)
        logger.info((f'In the {cache_type} cache[{cache_index}], the {input_index}th input tensor shape is {input_arg},'
                     f'cache shape is {cache}, not equal, need generalize to {shape_value}.'))
        return shape_value

    def update_phase_and_compile_args(self, args, phase, is_sink_mode, aux=None):
        """Save compile args and phase into the caches, unless auto dynamic shape has been disabled."""
        if not self.is_enable_auto_dynamic_shape:
            return
        self.auto_dynamic_shape_manager.update_phase_and_compile_args(args, phase, self.save_cache_number,
                                                                      is_sink_mode, aux)

    def _check_real_shape_cache(self, res_shape, args_list):
        """Return True if the real args should be used (real cache not full, or args already cached)."""
        real_cache_number = self.auto_dynamic_shape_manager.get_real_shape_cache_number()
        if real_cache_number < 2:
            logger.info((f'real shape cache cap is {real_cache_number}, smaller than 2, '
                         f'compile args shape={res_shape}.'))
            return True

        is_real_shape_exist = self.auto_dynamic_shape_manager.find_compile_args_in_shape_cache(args_list, "real")
        if is_real_shape_exist:
            logger.debug((f'find compile args in real shape cache, compile args shape={res_shape}'))
            return True

        return False

    def _generate_with_generalize_shape(self, generalize_shape_args, is_sink_mode, args_list):
        """Generalize ``generalize_shape_args`` further against the generalize cache, else return ``args_list``."""
        new_generalize_shape, can_generalize = self._do_generalize_shape("generalize", generalize_shape_args,
                                                                         is_sink_mode)
        if not can_generalize:
            return args_list

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(new_generalize_shape, is_sink_mode)
        logger.info((f'generalize with generalize shape cache, compile args shape = {res_shape}'))
        return new_generalize_shape

    def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode):
        """Return the compile args to use for ``args_list``: the real args or a generalized version.

        If the inputs are incompatible with the cache or with the feature, auto
        dynamic shape is disabled for this instance and ``args_list`` is returned.
        """
        if not self._check_input_number_and_type(args_list) or \
                not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode):
            self.is_enable_auto_dynamic_shape = False
            return args_list
        self.is_sink_mode = is_sink_mode

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(args_list, is_sink_mode)
        logger.debug((f'input args list shape = {res_shape}.'))

        # step1: find cache in real_shape_cache.
        if self._check_real_shape_cache(res_shape, args_list):
            return args_list

        # step2: if can not find cache in real_shape_cache, then generate it
        generalize_shape_args, can_generalize = self._do_generalize_shape("real", args_list, is_sink_mode)
        if not can_generalize:
            return args_list

        if self.auto_dynamic_shape_manager.get_generalize_shape_cache_number() == 0:
            res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(generalize_shape_args, is_sink_mode)
            logger.info((f'generalize shape cache cap is smaller than 1, compile args shape = {res_shape}.'))
            return generalize_shape_args

        # step3: find generalize_shape in generalize_shape_cache
        is_generalize_shape_exist = \
            self.auto_dynamic_shape_manager.find_compile_args_in_shape_cache(generalize_shape_args, "generalize")

        # step 4: if can not find cache in generalize_shape_cache, then generate it again
        if not is_generalize_shape_exist:
            return self._generate_with_generalize_shape(generalize_shape_args, is_sink_mode, args_list)

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(generalize_shape_args, is_sink_mode)
        logger.debug((f'find compile args in generalize shape cache, compile args shape={res_shape}'))
        return generalize_shape_args

    def _cal_unknown_shape_count(self, generalize_shape_args, is_sink_mode):
        """Count unknown ranks and unknown shapes in ``generalize_shape_args``.

        Tensors are inspected via their shape. Nested lists/tuples are counted
        recursively. In sink mode, bare ints equal to ``SHAPE_DIM_ANY`` or
        ``SHAPE_RANK_ANY`` are counted too.

        Returns:
            tuple, ``(unknown_rank_count, unknown_shape_count)``.
        """
        unknown_shape_count = 0
        unknown_rank_count = 0
        for elem in generalize_shape_args:
            if isinstance(elem, (list, tuple)):
                rank_count, shape_count = self._cal_unknown_shape_count(elem, is_sink_mode)
                unknown_rank_count += rank_count
                unknown_shape_count += shape_count
            if isinstance(elem, Tensor):
                if is_shape_unknown(elem.shape):
                    unknown_shape_count += 1
                if is_dim_unknown(elem.shape):
                    unknown_rank_count += 1
            if is_sink_mode and isinstance(elem, int):
                if elem == SHAPE_DIM_ANY:
                    unknown_shape_count += 1
                if elem == SHAPE_RANK_ANY:
                    unknown_rank_count += 1
        return unknown_rank_count, unknown_shape_count

    def _do_generalize_one_input_shape(self, input_args, cache_args, cache_type, index, is_sink_mode):
        """Generalize ``input_args`` against one cached entry ``cache_args``.

        Returns:
            tuple, ``(generalized_args, True)`` on success. ``(input_args, False)``
            when the two cannot be reconciled: different lengths, unequal
            Parameters, unequal scalars, or mismatched argument types.
        """
        def generalize_tensor(arg, cache, i):
            if self.auto_dynamic_shape_manager.is_tensor_equal(arg, cache):
                return arg

            if len(arg.shape) != len(cache.shape):
                shape_tuple = (SHAPE_RANK_ANY,)
            else:
                shape_tuple = tuple(SHAPE_DIM_ANY for _ in range(len(arg.shape)))
            logger.info((f'In the {cache_type} cache[{index}], the {i}th input tensor shape is {arg.shape},'
                         f'cache shape is {cache.shape}, not equal, need generalize to {shape_tuple}.'))
            return Tensor(shape=shape_tuple, dtype=arg.dtype)

        def generalize_sequence(arg, cache, i):
            """Return ``(generalized_sequence, can_generalize)``."""
            if is_sink_mode:
                # when is_sink_mode=True, input must be the shape of Tensor.
                return self._do_generalize_in_sink(arg, cache, i, index, cache_type), True

            nested, can_generalize = self._do_generalize_one_input_shape(arg, cache, cache_type, index,
                                                                          is_sink_mode)
            if not can_generalize:
                return arg, False
            # Keep tuple inputs as tuples so the compile args keep their structure.
            return (tuple(nested) if isinstance(arg, tuple) else nested), True

        if len(input_args) != len(cache_args):
            logger.info("In auto dynamic shape mode, input length is not equal to cache length, it can not be "
                        "generalize.")
            return input_args, False

        generalize_one_shape = []
        for i, (arg, cache) in enumerate(zip(input_args, cache_args)):
            if isinstance(arg, Parameter) and isinstance(cache, Parameter):
                if self.auto_dynamic_shape_manager.is_tensor_equal(arg, cache):
                    generalize_one_shape.append(arg)
                    continue

                logger.info("In auto dynamic shape mode, parameter must be equal, it can not be generalize.")
                return input_args, False

            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                generalize_one_shape.append(generalize_tensor(arg, cache, i))
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                res, can_generalize = generalize_sequence(arg, cache, i)
                if not can_generalize:
                    return input_args, False
                generalize_one_shape.append(res)
            elif isinstance(arg, (int, float)) and isinstance(cache, (int, float)):
                # when is_sink_mode=False, the input may be a scalar or a value inside a list/tuple;
                # scalars cannot be generalized, so they must be equal.
                if arg != cache:
                    logger.info("In auto dynamic shape mode, scalar/tuple/list must be equal, it can not be " \
                                "generalize.")
                    return input_args, False
                generalize_one_shape.append(arg)
            elif arg is None and cache is None:
                generalize_one_shape.append(arg)
            else:
                # Any other pairing used to be silently dropped, shortening the compile args.
                logger.info(f"In auto dynamic shape mode, the {i}th input type {type(arg)} is not match the cache "
                            f"type {type(cache)}, it can not be generalize.")
                return input_args, False

        return generalize_one_shape, True

    def _do_generalize_shape(self, cache_type, input_args, is_sink_mode):
        """Generalize ``input_args`` against every entry of the ``cache_type`` cache.

        Returns:
            tuple, ``(args, True)`` with the least-generalized candidate. Candidates
            are ranked by fewest unknown ranks, then fewest unknown shapes; the
            earliest cache entry wins ties. Returns ``(args, False)`` if any
            cache entry cannot be generalized with the input, or the cache is empty.
        """
        shape_cache = self.auto_dynamic_shape_manager.get_cache_by_type(cache_type)
        candidates = []
        for index, cache_args in enumerate(shape_cache):
            generalize_shape, can_generalize = self._do_generalize_one_input_shape(input_args, cache_args, cache_type,
                                                                                   index, is_sink_mode)
            if not can_generalize:
                return generalize_shape, False
            candidates.append(tuple(generalize_shape))

        if not candidates:
            return input_args, False

        # min() returns the first minimal candidate; keys compare as (rank_count, shape_count).
        best = min(candidates, key=lambda args: self._cal_unknown_shape_count(args, is_sink_mode))
        return best, True


_auto_dynamic_shape = _AutoIdentifyDynamicShape()


def get_auto_dynamic_shape_args(compile_args, key_id):
    """Return the compile args chosen by the (lazily created) identifier registered under ``key_id``."""
    if key_id not in auto_dynamic_shepe_dict:
        auto_dynamic_shepe_dict[key_id] = _AutoIdentifyDynamicShape()
    compile_args = auto_dynamic_shepe_dict[key_id].auto_dynamic_generate_compile_args(compile_args, False)
    return compile_args


def update_auto_dynamic_shape_phase(compile_args, key_id, phase):
    """Record ``phase`` and ``compile_args`` for ``key_id`` if an identifier is registered for it."""
    if key_id in auto_dynamic_shepe_dict:
        auto_dynamic_shepe_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)


def get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id, input_signature):
    """Like ``get_auto_dynamic_shape_args``, but skipped when an explicit ``input_signature`` is given."""
    if input_signature is None:
        return get_auto_dynamic_shape_args(compile_args, key_id)
    return compile_args


def update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, input_signature):
    """Like ``update_auto_dynamic_shape_phase``, but skipped when an explicit ``input_signature`` is given."""
    if input_signature is None:
        update_auto_dynamic_shape_phase(compile_args, key_id, phase)