• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import inspect
17from functools import wraps
18from mindspore import nn
19import mindspore as ms
20from mindspore import Tensor, jit, JitConfig
21import numpy as np
22
# Import-time side effect: set MindSpore's global context so that JIT syntax
# level is STRICT for everything that runs after this module is imported.
ms.set_context(jit_syntax_level=ms.STRICT)
24
25
class Net(nn.Cell):
    """Adapter cell that forwards every call to an arbitrary callable.

    Calling an instance goes through ``nn.Cell.__call__``, which (per
    MindSpore's Cell contract) runs ``construct``. Presumably this exists so a
    plain function can be executed the way MindSpore executes cells (e.g.
    compiled under graph mode) -- confirm against the MindSpore ``nn.Cell`` docs.
    """

    def __init__(self, func):
        """Store ``func``, the callable that ``construct`` will invoke."""
        super().__init__()
        self.func = func

    def construct(self, *inputs, **kwargs):
        # Forward all positional and keyword arguments unchanged and return
        # whatever ``func`` returns.
        # NOTE: construct's source may be parsed by MindSpore's graph compiler,
        # so it is intentionally kept minimal (comments only, no docstring).
        return self.func(*inputs, **kwargs)
33
34
def run_with_cell(fn):
    """Decorate ``fn`` so each call is routed through a freshly built ``Net`` cell.

    :param fn: the callable to wrap; must not be None.
    :return: a wrapper with ``fn``'s metadata that builds ``Net(fn)`` and calls it
        with the given arguments, returning its result.
    :raises ValueError: if ``fn`` is None.
    """
    if fn is not None:
        @wraps(fn)
        def _invoke_via_cell(*args, **kwargs):
            # A brand-new cell per call; arguments pass through untouched.
            return Net(fn)(*args, **kwargs)

        return _invoke_via_cell
    raise ValueError("fn cannot be none!")
45
46
def run_with_mode(fn):
    """Decorate ``fn`` so the caller chooses its execution mode on every call.

    The decorated function requires a ``mode`` keyword argument. The wrapper
    consumes it, so it is never forwarded to ``fn``. It is matched
    case-insensitively:

    * ``'pynative'`` -- call ``fn`` directly.
    * ``'graph'``    -- call ``jit(fn)`` with ``JitConfig(jit_level="O2")``.
    * ``'kbk'``      -- call ``jit(fn)`` with ``JitConfig(jit_level="O0")``.

    The returned wrapper is tagged with ``__wrapped_with_mode__ = True``.

    :param fn: the callable to wrap; must not be None.
    :return: the mode-dispatching wrapper.
    :raises ValueError: if ``fn`` is None, if ``mode`` is missing, or if ``mode``
        is not one of the strings above (non-string values included).
    """
    if fn is None:
        raise ValueError("fn cannot be none!")

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if 'mode' not in kwargs:
            raise ValueError("mode not provided.")
        # Consume ``mode`` so it is never forwarded to ``fn``.
        mode = kwargs.pop('mode')
        # A non-string value (e.g. the integer ``ms.GRAPH_MODE``) used to crash
        # with AttributeError on ``.lower()``. Reject it with the same
        # ValueError as any other unsupported mode.
        mode = mode.lower() if isinstance(mode, str) else None
        if mode not in ['pynative', 'graph', 'kbk']:
            raise ValueError(
                "Invalid mode. Available option: ['pynative', 'graph', 'kbk'].")

        if mode == "graph":
            return (jit(fn, jit_config=JitConfig(jit_level="O2")))(*args, **kwargs)
        if mode == "kbk":
            return (jit(fn, jit_config=JitConfig(jit_level="O0")))(*args, **kwargs)
        return fn(*args, **kwargs)

    setattr(wrapper, "__wrapped_with_mode__", True)
    return wrapper
69
70
def run_with_cell_ext(jit_config=None):
    """Build a decorator that runs the decorated function inside a new ``Net`` cell.

    :param jit_config: optional config handed to ``set_jit_config`` on every cell
        the wrapper creates; any falsy value (including the default None) skips it.
    :return: a decorator. Applying it to None raises ``ValueError``.
    """
    def cell_wrap_fn(fn):
        if fn is None:
            raise ValueError("fn cannot be none!")

        def _run_in_cell(*args, **kwargs):
            net = Net(fn)
            # Any truthy config is applied before the call.
            if jit_config:
                net.set_jit_config(jit_config)
            return net(*args, **kwargs)

        return wraps(fn)(_run_in_cell)

    return cell_wrap_fn
86
87
def to_cell_obj(fn):
    """Return a new ``Net`` cell wrapping ``fn``. The cell is built but not called."""
    return Net(fn)
91
92
def compare(output, expect):
    '''
    Check that ``output`` matches ``expect`` element-wise, recursing into sequences.

    :param output: Tensor (anything exposing ``asnumpy()``), or a tuple/list of
        them, possibly nested.
    :param expect: Numpy array, or a tuple/list of Numpy arrays with the same
        structure as ``output``.
    :return: None.
    :raises ValueError: if a tuple/list ``output`` and its ``expect`` differ in
        length, or if any leaf pair is not close. float32 expectations use
        rtol = atol = 1e-4; every other dtype uses 1e-3. NaNs compare equal.
    '''
    if isinstance(output, (tuple, list)):
        # zip() alone would silently drop unmatched trailing elements and let
        # a structural mismatch pass, so require equal lengths explicitly.
        if len(output) != len(expect):
            raise ValueError(
                f"compare failed \n output has {len(output)} elements, "
                f"expect has {len(expect)}")
        for o_ele, e_ele in zip(output, expect):
            compare(o_ele, e_ele)
        return

    if expect.dtype == np.float32:
        rtol, atol = 1e-4, 1e-4
    else:
        rtol, atol = 1e-3, 1e-3
    out_np = output.asnumpy()
    if not np.allclose(out_np, expect, rtol, atol, equal_nan=True):
        raise ValueError(f"compare failed \n output: {out_np}\n expect: {expect}")
109
110
def get_inputs_np(shapes, dtypes):
    """Generate reproducible standard-normal numpy inputs.

    Re-seeds numpy's global RNG with 10 on every call, which mutates global
    state. It then draws one ``randn`` array per ``(shape, dtype)`` pair, in
    order, cast to that dtype. Extra entries in the longer of ``shapes`` /
    ``dtypes`` are ignored.

    :return: list of numpy arrays.
    """
    np.random.seed(10)
    return [np.random.randn(*shape).astype(dtype)
            for shape, dtype in zip(shapes, dtypes)]
117
118
def get_inputs_tensor(inputs_np):
    """Wrap each numpy array of ``inputs_np`` in a MindSpore ``Tensor``, returned as a list."""
    return [Tensor(arr) for arr in inputs_np]
124
125
def need_run_graph_op_mode(func, args, kwargs):
    """Decide whether ``func`` should additionally be run in graph-op mode.

    Returns True only when the device target is ``'Ascend'`` and ``kwargs``
    (if it is a dict) supplies ``mode`` or ``context_mode`` with a value equal
    to ``ms.GRAPH_MODE``. The keys are checked in that order, and a key counts
    only if ``func``'s signature declares a parameter of that name.
    ``args`` is accepted but not used.
    """
    if ms.get_context('device_target') != 'Ascend':
        return False

    # Parameter mapping keyed by parameter name.
    params = inspect.signature(func).parameters

    mode = None
    if isinstance(kwargs, dict):
        mode = next(
            (kwargs[name] for name in ('mode', 'context_mode')
             if name in params and name in kwargs),
            None)

    return mode == ms.GRAPH_MODE
142
143
def run_test_with_On(test_func):
    """Decorate a test so it may run a second time with ``jit_level='O0'``.

    The test always runs once as-is. Afterwards, if ``need_run_graph_op_mode``
    reports that a graph-op rerun is needed, the test runs again with the
    context's ``jit_level`` set to ``'O0'``. The previous ``jit_level`` is
    restored afterwards, even if the rerun raises.
    """
    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # The ordinary run, always performed first.
        test_func(*args, **kwargs)

        if need_run_graph_op_mode(test_func, args, kwargs):
            saved_level = ms.get_context('jit_level')
            try:
                # Rerun the graph in kernel-by-kernel mode.
                ms.set_context(jit_level='O0')
                test_func(*args, **kwargs)
            finally:
                ms.set_context(jit_level=saved_level)

    return wrapper
163