# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Helpers for running functions under different MindSpore execution modes
and for comparing MindSpore outputs against numpy expectations."""

import inspect
from functools import wraps
from mindspore import nn
import mindspore as ms
from mindspore import Tensor, jit, JitConfig
import numpy as np

# Applied once, at import time, for every user of this module.
ms.set_context(jit_syntax_level=ms.STRICT)


class Net(nn.Cell):
    """Cell whose ``construct`` forwards every call to a wrapped callable.

    Lets a plain Python function be invoked through the ``nn.Cell`` call path.
    """

    def __init__(self, func):
        super().__init__()
        # The callable that ``construct`` delegates to.
        self.func = func

    # Call the wrapped function with all positional and keyword arguments
    # and return its result unchanged.
    def construct(self, *inputs, **kwargs):
        return self.func(*inputs, **kwargs)


def run_with_cell(fn):
    """Decorator: execute ``fn`` inside a freshly built ``Net`` cell.

    Args:
        fn: The function to wrap.

    Returns:
        A wrapper carrying ``fn``'s metadata. Each call constructs a new
        ``Net(fn)`` and invokes it with the call's arguments.

    Raises:
        ValueError: If ``fn`` is None.
    """
    if fn is None:
        raise ValueError("fn cannot be none!")

    @wraps(fn)
    def wrapper(*args, **kwargs):
        return Net(fn)(*args, **kwargs)

    return wrapper


# Execution modes that go through ``jit`` and the ``jit_level`` each one uses.
# The remaining valid mode, 'pynative', calls the function directly.
_MODE_TO_JIT_LEVEL = {"graph": "O2", "kbk": "O0"}


def run_with_mode(fn):
    """Decorator: run ``fn`` in the mode named by a mandatory ``mode`` kwarg.

    The wrapper requires a ``mode`` keyword argument (case-insensitive). It
    removes that argument before calling ``fn``, so ``fn`` never receives it:

    * ``'pynative'``: call ``fn`` directly.
    * ``'graph'``: call ``jit(fn, jit_config=JitConfig(jit_level="O2"))``.
    * ``'kbk'``: call ``jit(fn, jit_config=JitConfig(jit_level="O0"))``.

    A new ``jit`` wrapper is built on every call. The returned wrapper is
    tagged with the attribute ``__wrapped_with_mode__ = True``.

    Raises:
        ValueError: If ``fn`` is None (at decoration time), or if ``mode``
            is missing or not one of the options above (at call time).
    """
    if fn is None:
        raise ValueError("fn cannot be none!")

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if 'mode' not in kwargs:
            raise ValueError("mode not provided.")
        # Consume ``mode`` so it is not forwarded to ``fn``.
        mode = kwargs.pop('mode').lower()
        if mode not in ['pynative', 'graph', 'kbk']:
            raise ValueError(
                "Invalid mode. Available option: ['pynative', 'graph', 'kbk'].")

        jit_level = _MODE_TO_JIT_LEVEL.get(mode)
        if jit_level is None:
            # 'pynative': plain eager call, no jit.
            return fn(*args, **kwargs)
        return jit(fn, jit_config=JitConfig(jit_level=jit_level))(*args, **kwargs)

    setattr(wrapper, "__wrapped_with_mode__", True)
    return wrapper


def run_with_cell_ext(jit_config=None):
    """Decorator factory: like ``run_with_cell``, optionally applying a JIT config.

    Args:
        jit_config: If truthy, passed to ``Net.set_jit_config`` on every
            freshly built cell before it is called.

    Returns:
        A decorator that wraps a function in a new ``Net`` per call.

    Raises:
        ValueError: (from the returned decorator) if the decorated ``fn`` is None.
    """
    def cell_wrap_fn(fn):
        if fn is None:
            raise ValueError("fn cannot be none!")

        @wraps(fn)
        def wrapper(*args, **kwargs):
            cell_obj = Net(fn)
            if jit_config:
                cell_obj.set_jit_config(jit_config)
            return cell_obj(*args, **kwargs)

        return wrapper

    return cell_wrap_fn


def to_cell_obj(fn):
    """Return a new ``Net`` cell that forwards calls to ``fn``."""
    return Net(fn)


def compare(output, expect):
    '''
    Check that ``output`` matches ``expect`` within a dtype-dependent tolerance.

    Sequences are compared element-wise and recursively. Leaves are compared
    with ``np.allclose`` (NaNs compare equal). Tolerance is
    rtol = atol = 1e-4 when ``expect.dtype`` is float32, and 1e-3 otherwise.

    :param output: Tensor, including tuple/list of tensor
    :param expect: Numpy array, including tuple/list of Numpy array
    :return: None
    :raises ValueError: if a sequence ``output`` and its ``expect`` differ in
        length, or if any leaf pair is not close.
    '''
    if isinstance(output, (tuple, list)):
        # zip() would silently ignore surplus elements on either side, letting
        # a missing or extra output pass unchecked; treat that as a failure.
        if len(output) != len(expect):
            raise ValueError(f"compare failed \n output length: {len(output)}\n"
                             f" expect length: {len(expect)}")
        for o_ele, e_ele in zip(output, expect):
            compare(o_ele, e_ele)
        return

    if expect.dtype == np.float32:
        rtol, atol = 1e-4, 1e-4
    else:
        rtol, atol = 1e-3, 1e-3
    output_np = output.asnumpy()
    if not np.allclose(output_np, expect, rtol=rtol, atol=atol, equal_nan=True):
        raise ValueError(f"compare failed \n output: {output_np}\n expect: {expect}")


def get_inputs_np(shapes, dtypes, seed=10):
    """Generate reproducible standard-normal numpy arrays.

    Note: this reseeds numpy's *global* random generator with ``seed``
    (``np.random.seed``), which also affects later ``np.random`` calls.

    Args:
        shapes: Iterable of shapes, one per array.
        dtypes: Iterable of dtypes, paired with ``shapes`` via ``zip``.
        seed: Seed for the global numpy RNG (default 10).

    Returns:
        List of arrays from ``np.random.randn(*shape).astype(dtype)``.
    """
    np.random.seed(seed)
    return [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(shapes, dtypes)]


def get_inputs_tensor(inputs_np):
    """Wrap each numpy array in ``inputs_np`` in a MindSpore ``Tensor``; return a list."""
    return [Tensor(input_np) for input_np in inputs_np]


def need_run_graph_op_mode(func, args, kwargs):
    """Decide whether ``func`` should be re-run with ``jit_level='O0'``.

    Returns True only when the device target is ``'Ascend'`` and ``func``
    declares a ``mode`` or ``context_mode`` parameter (checked in that order)
    that was passed by keyword with value ``ms.GRAPH_MODE``.

    Args:
        func: The test function whose signature is inspected.
        args: Positional call arguments. Accepted for interface compatibility
            but not inspected, so a mode passed positionally is not detected.
        kwargs: Keyword call arguments. Anything other than a dict yields False.
    """
    if ms.get_context('device_target') != 'Ascend':
        return False

    # Parameter names declared by func (a name -> Parameter mapping).
    params = inspect.signature(func).parameters

    if not isinstance(kwargs, dict):
        return False
    for key in ('mode', 'context_mode'):
        if key in params and key in kwargs:
            return kwargs[key] == ms.GRAPH_MODE
    return False


def run_test_with_On(test_func):
    """Decorator: also run a graph-mode test in kernel-by-kernel mode.

    The wrapped test always runs once as written. If
    ``need_run_graph_op_mode`` returns True for this call, the test runs a
    second time with the context ``jit_level`` set to ``'O0'``. The previous
    ``jit_level`` is restored afterwards, even if the second run raises.
    The test's return value is discarded.
    """

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # Call the original test function.
        test_func(*args, **kwargs)

        if not need_run_graph_op_mode(test_func, args, kwargs):
            return

        org_jit_level = ms.get_context('jit_level')
        try:
            # Run graph in kernel-by-kernel mode.
            ms.set_context(jit_level='O0')
            test_func(*args, **kwargs)
        finally:
            ms.set_context(jit_level=org_jit_level)

    return wrapper