• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert inputs to dynamic shape automatically on the first round. The methods in this file are used only for testing."""
18import os
19from mindspore.common.parameter import Parameter
20from mindspore.common.tensor import Tensor
21
22
def is_auto_dynamic_shape():
    """Check whether automatic dynamic-shape conversion is switched on; used only for test.

    Returns:
        bool: True only when the environment variable ``MS_DEV_AUTO_DYNAMIC_SHAPE``
        is set to exactly ``"on"`` (case-sensitive); False when it is unset or
        holds any other value.
    """
    return os.environ.get("MS_DEV_AUTO_DYNAMIC_SHAPE") == "on"
26
27
def is_auto_dynamic_rank():
    """Check whether automatic dynamic-rank conversion is switched on; used only for test.

    Returns:
        bool: True only when the environment variable ``MS_DEV_AUTO_DYNAMIC_RANK``
        is set to exactly ``"on"`` (case-sensitive); False when it is unset or
        holds any other value.
    """
    return os.environ.get("MS_DEV_AUTO_DYNAMIC_RANK") == "on"
31
32
def is_auto_dynamic():
    """Check whether any automatic dynamic conversion is switched on; used only for test.

    Returns:
        bool: True when either ``is_auto_dynamic_shape()`` or
        ``is_auto_dynamic_rank()`` reports True, otherwise False.
    """
    return is_auto_dynamic_shape() or is_auto_dynamic_rank()
36
37
def convert_inputs_to_dynamic(*inputs):
    """Replace tensor inputs with dynamic placeholders; used only for test.

    Args:
        *inputs: Network inputs in positional order.

    Returns:
        None when called without any input. Otherwise a tuple with the same
        length and order as ``inputs`` in which every ``Tensor`` that is not a
        ``Parameter`` and whose ``shape`` is non-empty is replaced by a new
        ``Tensor`` of the same dtype. The new tensor's shape is ``None`` when
        ``is_auto_dynamic_rank()`` is True, else a list holding one ``None``
        per original dimension. Every other input is passed through unchanged.
    """
    if not inputs:
        return None

    # Loop-invariant: read the environment switch once, not once per tensor.
    dynamic_rank = is_auto_dynamic_rank()
    dyn_inputs = list(inputs)
    for idx, net_input in enumerate(inputs):
        # Only plain tensors are converted; Parameters and non-tensors stay as given.
        if not isinstance(net_input, Tensor) or isinstance(net_input, Parameter):
            continue
        shape = net_input.shape
        if not shape:
            # Tensors with an empty shape are kept as they are.
            continue
        dyn_shape = None if dynamic_rank else [None for _ in shape]
        dyn_inputs[idx] = Tensor(shape=dyn_shape, dtype=net_input.dtype)

    return tuple(dyn_inputs)
57
58
def convert_new_shapes(dataset_shapes):
    """Build placeholder shapes matching ``dataset_shapes``; used only for test.

    Args:
        dataset_shapes: Iterable of shapes, each itself an iterable of dimensions.

    Returns:
        list: One new list per entry of ``dataset_shapes``. When
        ``is_auto_dynamic_rank()`` is True every entry becomes ``[-2]``;
        otherwise it becomes a list with one ``-1`` per original dimension.
        NOTE(review): presumably -2 marks a dynamic rank and -1 a dynamic
        dimension -- confirm against the MindSpore shape conventions.
    """
    # The environment switch is the same for every shape, so check it once.
    if is_auto_dynamic_rank():
        return [[-2] for _ in dataset_shapes]
    return [[-1 for _ in shape] for shape in dataset_shapes]
69