# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
17"""Providing auto dynamic shape interface methods."""
18
19import os
20from mindspore import log as logger
21from mindspore._c_expression import GraphExecutor_, Tensor
22from mindspore.common._utils import is_shape_unknown, is_dim_unknown
23from mindspore.common.parameter import Parameter
24
25SHAPE_DIM_ANY = -1
26SHAPE_RANK_ANY = -2
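# SHAPE_DIM_ANY (-1) marks a single unknown dimension, while SHAPE_RANK_ANY (-2)
# marks an input whose rank itself is unknown.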

# Maps a caller-provided key_id to its _AutoIdentifyDynamicShape instance, so each
# compiled function keeps an independent set of shape caches.
auto_dynamic_shape_dict = {}


class _AutoDynamicShapeManager:
    """
    Manages the shape caches used to automatically identify dynamic shapes.
    """
    def __init__(self):
        # Compile args whose exact shapes have been compiled, and their generalized counterparts.
        self.real_shape_cache = []
        self.generalize_shape_cache = []
        # Map each compiled phase to the compile args it was built from.
        self.real_phase_and_compile_args_dict = {}
        self.generalize_phase_and_compile_args_dict = {}
        self._graph_executor = GraphExecutor_.get_instance()


    def __del__(self):
        self.real_shape_cache = []
        self.generalize_shape_cache = []
        self.real_phase_and_compile_args_dict = {}
        self.generalize_phase_and_compile_args_dict = {}


    @staticmethod
    def is_tensor_equal(input_elem, cache_elem):
        """check whether two tensors or parameters have the same shape and dtype"""
        if input_elem.shape == cache_elem.shape and input_elem.dtype == cache_elem.dtype:
            return True
        return False


    @staticmethod
    def _get_input_generalize_number(arg_list, is_shape_input):
        """count how many inputs have been generalized to an unknown shape"""
        count = 0
        if is_shape_input:
            for arg in arg_list:
                if is_shape_unknown(arg):
                    count = count + 1
        else:
            for arg in arg_list:
                if isinstance(arg, Tensor) and is_shape_unknown(arg.shape):
                    count = count + 1
        return count


    def get_real_shape_cache_number(self):
        """get real shape cache number"""
        return len(self.real_shape_cache)


    def get_real_shape_cache(self):
        """get real shape cache"""
        return self.real_shape_cache


    def get_generalize_shape_cache_number(self):
        """get generalize shape cache number"""
        return len(self.generalize_shape_cache)


    def get_generalize_shape_cache(self):
        """get generalize shape cache"""
        return self.generalize_shape_cache


    def get_cache_by_type(self, cache_type):
        """get cache by type"""
        if cache_type == "real":
            shape_cache = self.real_shape_cache
        else:
            shape_cache = self.generalize_shape_cache

        return shape_cache


    def get_compile_args_shape_without_sink(self, input_args, res_shape):
        """get compile args shape without sink mode"""
        for arg in input_args:
            if isinstance(arg, Tensor):
                res_shape.append(arg.shape)
            elif isinstance(arg, (int, float)):
                res_shape.append([])
            elif isinstance(arg, (tuple, list)):
                tmp_shape = []
                self.get_compile_args_shape_without_sink(arg, tmp_shape)
                res_shape.append(tmp_shape)


    def get_compile_args_shape(self, input_args, is_sink_mode):
        """get compile args shape"""
        if is_sink_mode:
            return input_args

        res_shape = []
        self.get_compile_args_shape_without_sink(input_args, res_shape)
        return res_shape


    def find_compile_args_in_shape_cache(self, input_args, cache_type):
        """find compile args in real or generalize shape cache"""
        shape_cache = self.get_cache_by_type(cache_type)
        for cache_args in shape_cache:
            res = self._compare_input_args_and_cache_args(input_args, cache_args)
            if res:
                return True
        return False


    def update_phase_and_compile_args(self, compile_args, phase, save_cache_number, is_sink_mode, aux=None):
        """update compile args and phase"""

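        # Cache policy (illustrative, with the default save_cache_number of 3): the first
        # two distinct real shapes are kept and compiled as-is; after that, at most
        # save_cache_number - 2 generalized entries are kept, and the oldest generalized
        # phase is evicted first.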
        if phase in self.real_phase_and_compile_args_dict:
            logger.debug(f'phase={phase!r} is already in the real phase and compile args dict.')
            return

        if phase in self.generalize_phase_and_compile_args_dict:
            logger.debug(f'phase={phase!r} is already in the generalize phase and compile args dict.')
            return

        if len(self.real_shape_cache) < 2:
            logger.debug(f'The real shape cache number is {len(self.real_shape_cache)}, which is less than 2, '
                         f'so phase={phase!r} is saved in the real shape cache.')
            self.real_phase_and_compile_args_dict[phase] = compile_args
            self.real_shape_cache.append(compile_args)
            return

        # Two cache slots are reserved for the real shape cache.
        max_save_cache_number = save_cache_number - 2

        if len(self.generalize_phase_and_compile_args_dict) >= max_save_cache_number:
            # step1: find the phase to delete (the oldest generalized entry)
            phase_list = list(self.generalize_phase_and_compile_args_dict.keys())
            delete_phase = phase_list[0]
            delete_compile_args = self.generalize_phase_and_compile_args_dict.get(delete_phase)

            # step2: delete the phase cache
            if is_sink_mode:
                if hasattr(aux, '__network_manage__') and delete_phase in aux.__network_manage__:
                    del aux.__network_manage__[delete_phase]
                    del self.generalize_phase_and_compile_args_dict[delete_phase]
            else:
                delete_cache = set()
                delete_cache.add(delete_phase)
                self._graph_executor.del_net_res(None, delete_cache)
                del self.generalize_phase_and_compile_args_dict[delete_phase]

            # step3: delete the compile args
            self.generalize_shape_cache.remove(delete_compile_args)

        # step4: save the phase and compile args into the cache
        logger.info(f'The generalize shape cache number is {len(self.generalize_shape_cache)}, which is less than '
                    f'{max_save_cache_number}, so phase={phase!r} is saved in the generalize shape cache.')
        self.generalize_phase_and_compile_args_dict[phase] = compile_args
        self.generalize_shape_cache.append(compile_args)


    def _compare_input_args_and_cache_args(self, input_args, cache_args):
        """compare input args and cache args"""
        for (arg, cache) in zip(input_args, cache_args):
            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                if not self.is_tensor_equal(arg, cache):
                    return False
            elif isinstance(arg, int) and isinstance(cache, int):
                if arg != cache:
                    return False
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                if not self._compare_input_args_and_cache_args(arg, cache):
                    return False
        return True


class _AutoIdentifyDynamicShape:
    """
    Automatically identifies dynamic shapes and generalizes compile arguments accordingly.
    """
    def __init__(self):
        self.all_shape_cache = {}
        self.is_sink_mode = False
        self.is_enable_auto_dynamic_shape = True
        self.save_cache_number = 3
        self.enable_auto_identify = os.getenv('MS_AUTO_DYNAMIC_SHAPE_ENABLE')
        self.auto_dynamic_shape_manager = _AutoDynamicShapeManager()


    def __del__(self):
        self.all_shape_cache = {}
        self.is_sink_mode = False
        self.is_enable_auto_dynamic_shape = True
        self.save_cache_number = 3


    def _check_input_args_number(self, args_list):
        """check input arg number"""
        if self.auto_dynamic_shape_manager.get_real_shape_cache_number() > 0:
            first_real_cache = self.auto_dynamic_shape_manager.get_real_shape_cache()[0]
            if len(first_real_cache) != len(args_list):
                return False
        return True


    def _check_input_tensor_type(self, args_list, cache_list):
        """check input args type"""
        for (arg, cache) in zip(args_list, cache_list):
            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                if arg.dtype != cache.dtype:
                    logger.debug(f'input tensor type = {arg.dtype}, cache tensor type = {cache.dtype}, '
                                 f'the tensor types are not the same.')
                    return False
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                res = self._check_input_tensor_type(arg, cache)
                if not res:
                    return False
            elif (isinstance(arg, int) and isinstance(cache, int)) or \
                 (isinstance(arg, float) and isinstance(cache, float)):
                if arg != cache:
                    return False
            elif isinstance(arg, Tensor) and not isinstance(cache, Tensor):
                return False
            elif isinstance(arg, (int, float)) and not isinstance(cache, (int, float)):
                return False
            elif isinstance(arg, (tuple, list)) and not isinstance(cache, (tuple, list)):
                return False
        return True


    def _check_input_number_and_type(self, args_list):
        """check input number and type"""
        res = self._check_input_args_number(args_list)
        if not res:
            return False

        if self.auto_dynamic_shape_manager.get_real_shape_cache_number() > 0:
            cache_list = self.auto_dynamic_shape_manager.get_real_shape_cache()[0]
            res = self._check_input_tensor_type(args_list, cache_list)
            if not res:
                return False
        return True


    def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode):
        """check whether auto dynamic shape identification is enabled for these inputs"""
        if not is_sink_mode and not args_list:
            return False

        if not self.enable_auto_identify:
            self.enable_auto_identify = "0"

        if self.enable_auto_identify == "0":
            return False

        for elem in args_list:
            if elem is None:
                continue
            if not isinstance(elem, (list, tuple, Tensor, int, float)):
                return False
            if isinstance(elem, Tensor) and (is_shape_unknown(elem.shape) or (not elem.shape)):
                return False
            if not is_sink_mode and isinstance(elem, (list, tuple)):
                return self._is_enable_auto_dynamic_shape(elem, is_sink_mode)
        return True


    @staticmethod
    def _do_generalize_in_sink(input_arg, cache, input_index, cache_index, cache_type):
        """do generalize in sink mode, where each input is the shape of a tensor"""
        if not input_arg:
            raise ValueError("In sink mode, cell input can not be a scalar.")

        if input_arg == cache:
            return cache

        shape_value = []
        if len(input_arg) != len(cache):
            shape_value.append(SHAPE_RANK_ANY)
        else:
            for _ in input_arg:
                shape_value.append(SHAPE_DIM_ANY)
        logger.info(f'In the {cache_type} cache[{cache_index}], the {input_index}th input tensor shape is '
                    f'{input_arg}, the cache shape is {cache}; they are not equal and need to be generalized '
                    f'to {shape_value}.')
        return shape_value

    def update_phase_and_compile_args(self, args, phase, is_sink_mode, aux=None):
        """save compile args and phase into dict"""
        if not self.is_enable_auto_dynamic_shape:
            return
        self.auto_dynamic_shape_manager.update_phase_and_compile_args(args, phase, self.save_cache_number,
                                                                      is_sink_mode, aux)

    def _check_real_shape_cache(self, res_shape, args_list):
        """find cache in real_shape_cache"""
        real_cache_number = self.auto_dynamic_shape_manager.get_real_shape_cache_number()
        if real_cache_number < 2:
            logger.info(f'The real shape cache number is {real_cache_number}, which is smaller than 2, '
                        f'compile args shape={res_shape}.')
            return True

        is_real_shape_exist = self.auto_dynamic_shape_manager.find_compile_args_in_shape_cache(args_list, "real")
        if is_real_shape_exist:
            logger.debug(f'Found compile args in the real shape cache, compile args shape={res_shape}.')
            return True

        return False

    def _generate_with_generalize_shape(self, generalize_shape_args, is_sink_mode, args_list):
        """generate with the generalize shape cache"""
        new_generalize_shape, can_generalize = self._do_generalize_shape("generalize", generalize_shape_args,
                                                                         is_sink_mode)
        if not can_generalize:
            return args_list

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(new_generalize_shape, is_sink_mode)
        logger.info(f'Generalized with the generalize shape cache, compile args shape = {res_shape}.')
        return new_generalize_shape

    def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode):
        """generate compile args in auto dynamic shape"""
        if not self._check_input_number_and_type(args_list) or \
            not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode):
            self.is_enable_auto_dynamic_shape = False
            return args_list
        self.is_sink_mode = is_sink_mode

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(args_list, is_sink_mode)
        logger.debug(f'input args list shape = {res_shape}.')

        # step1: look for the args in real_shape_cache.
        if self._check_real_shape_cache(res_shape, args_list):
            return args_list

        # step2: if the args are not in real_shape_cache, generalize them against it.
        generalize_shape_args, can_generalize = self._do_generalize_shape("real", args_list, is_sink_mode)
        if not can_generalize:
            return args_list

        if self.auto_dynamic_shape_manager.get_generalize_shape_cache_number() == 0:
            res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(generalize_shape_args, is_sink_mode)
            logger.info(f'The generalize shape cache is empty, compile args shape = {res_shape}.')
            return generalize_shape_args

        # step3: look for the generalized shape in generalize_shape_cache.
        is_generalize_shape_exist = \
            self.auto_dynamic_shape_manager.find_compile_args_in_shape_cache(generalize_shape_args, "generalize")

        # step4: if the generalized shape is not in generalize_shape_cache, generalize it again.
        if not is_generalize_shape_exist:
            return self._generate_with_generalize_shape(generalize_shape_args, is_sink_mode, args_list)

        res_shape = self.auto_dynamic_shape_manager.get_compile_args_shape(generalize_shape_args, is_sink_mode)
        logger.debug(f'Found compile args in the generalize shape cache, compile args shape={res_shape}.')
        return generalize_shape_args


    def _cal_unknown_shape_count(self, generalize_shape_args, is_sink_mode):
        """count the unknown-rank and unknown-dimension inputs in the generalized shape args"""
        unknown_shape_count = 0
        unknown_rank_count = 0
        for elem in generalize_shape_args:
            if isinstance(elem, (list, tuple)):
                rank_count, shape_count = self._cal_unknown_shape_count(elem, is_sink_mode)
                unknown_rank_count = unknown_rank_count + rank_count
                unknown_shape_count = unknown_shape_count + shape_count
            if isinstance(elem, Tensor):
                if is_shape_unknown(elem.shape):
                    unknown_shape_count = unknown_shape_count + 1
                if is_dim_unknown(elem.shape):
                    unknown_rank_count = unknown_rank_count + 1
            if is_sink_mode and isinstance(elem, int):
                if elem == SHAPE_DIM_ANY:
                    unknown_shape_count = unknown_shape_count + 1
                if elem == SHAPE_RANK_ANY:
                    unknown_rank_count = unknown_rank_count + 1
        return unknown_rank_count, unknown_shape_count


    def _do_generalize_one_input_shape(self, input_args, cache_args, cache_type, index, is_sink_mode):
        """generalize the input shapes against one cache entry"""
        def generalize_tensor(arg, cache, i):
            if self.auto_dynamic_shape_manager.is_tensor_equal(arg, cache):
                return arg

            shape_value = []
            if len(arg.shape) != len(cache.shape):
                shape_value.append(SHAPE_RANK_ANY)
            else:
                shape_value = [SHAPE_DIM_ANY for _ in range(len(arg.shape))]
            shape_tuple = tuple(shape_value)
            logger.info(f'In the {cache_type} cache[{index}], the {i}th input tensor shape is {arg.shape}, '
                        f'the cache shape is {cache.shape}; they are not equal and need to be generalized '
                        f'to {shape_tuple}.')
            return Tensor(shape=shape_tuple, dtype=arg.dtype)

        def generalize_sequence(arg, cache, i):
            if is_sink_mode:
                # When is_sink_mode is True, the input must be the shape of a Tensor.
                res = self._do_generalize_in_sink(arg, cache, i, index, cache_type)
                return res

            res = self._do_generalize_one_input_shape(arg, cache, cache_type, index, is_sink_mode)
            return res

        generalize_one_shape = []
        for i, (arg, cache) in enumerate(zip(input_args, cache_args)):
            if isinstance(arg, Parameter) and isinstance(cache, Parameter):
                if self.auto_dynamic_shape_manager.is_tensor_equal(arg, cache):
                    generalize_one_shape.append(arg)
                    continue

                logger.info("In auto dynamic shape mode, parameters must be equal; they cannot be generalized.")
                return input_args, False

            if isinstance(arg, Tensor) and isinstance(cache, Tensor):
                res = generalize_tensor(arg, cache, i)
                generalize_one_shape.append(res)
            elif isinstance(arg, (tuple, list)) and isinstance(cache, (tuple, list)):
                res = generalize_sequence(arg, cache, i)
                generalize_one_shape.append(res)
            elif isinstance(arg, int) and isinstance(cache, int):
                # When is_sink_mode is False, the input may be a scalar or a value inside a list/tuple;
                # is_sink_mode cannot be True here.
                if arg == cache:
                    generalize_one_shape.append(arg)
                else:
                    logger.info("In auto dynamic shape mode, scalars/tuples/lists must be equal; "
                                "they cannot be generalized.")
                    return input_args, False
            elif arg is None and cache is None:
                generalize_one_shape.append(arg)

        return generalize_one_shape, True


    def _do_generalize_shape(self, cache_type, input_args, is_sink_mode):
        """do generalize shape by cache"""
        shape_cache = self.auto_dynamic_shape_manager.get_cache_by_type(cache_type)
        all_generalize_shape_args = []
        for index, cache_args in enumerate(shape_cache):
            generalize_shape, can_generalize = self._do_generalize_one_input_shape(input_args, cache_args, cache_type,
                                                                                   index, is_sink_mode)
            if not can_generalize:
                return generalize_shape, False
            all_generalize_shape_args.append(tuple(generalize_shape))

        unknown_shape_dict = {}
        for generalize_shape_args in all_generalize_shape_args:
            unknown_rank_count, unknown_shape_count = self._cal_unknown_shape_count(generalize_shape_args,
                                                                                    is_sink_mode)
            unknown_count = (unknown_rank_count, unknown_shape_count)
            if unknown_count not in unknown_shape_dict:
                unknown_shape_dict[unknown_count] = generalize_shape_args

        keys = list(unknown_shape_dict.keys())
        keys.sort(key=lambda x: (x[0], x[1]))
        return unknown_shape_dict.get(keys[0]), True

_auto_dynamic_shape = _AutoIdentifyDynamicShape()


def get_auto_dynamic_shape_args(compile_args, key_id):
    """get auto dynamic shape args."""
    if key_id not in auto_dynamic_shape_dict:
        auto_dynamic_shape_dict[key_id] = _AutoIdentifyDynamicShape()
    compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(compile_args, False)
    return compile_args


def update_auto_dynamic_shape_phase(compile_args, key_id, phase):
    """update auto dynamic shape phase."""
    if key_id in auto_dynamic_shape_dict:
        auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)


def get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id, input_signature):
    """get auto dynamic shape args."""
    if input_signature is None:
        return get_auto_dynamic_shape_args(compile_args, key_id)
    return compile_args


def update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, input_signature):
    """update auto dynamic shape phase."""
    if input_signature is None:
        if key_id in auto_dynamic_shape_dict:
            auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)

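
# Illustrative usage (a minimal sketch; these helpers are normally invoked by MindSpore's
# JIT compile pipeline rather than called directly, and `net`, `inputs` and `phase` below
# are hypothetical placeholders):
#
#     key_id = id(net)
#     compile_args = get_auto_dynamic_shape_args(inputs, key_id)
#     # ... compile the graph with `compile_args`, which yields a `phase` string ...
#     update_auto_dynamic_shape_phase(compile_args, key_id, phase)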