# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility functions for mindspore.numpy st tests."""
import functools
import numpy as onp
from mindspore import Tensor
import mindspore.numpy as mnp

# Python values that have no ``tolist`` method and must be wrapped with
# ``numpy.asarray`` before comparison.
_PLAIN_PY_TYPES = (int, float, complex, tuple, list)


def match_array(actual, expected, error=0):
    """
    Assert that `actual` and `expected` contain the same values.

    Args:
        actual: value under test. A numpy array/scalar, or a Python
            scalar/tuple/list (converted with ``numpy.asarray``).
        expected: reference value, accepted in the same forms as `actual`.
        error (int): when positive, values are compared with
            ``numpy.testing.assert_almost_equal`` to `error` decimal places;
            otherwise exact equality is required.
    """
    # Python scalars and sequences have no ``tolist``; wrap them first.
    if isinstance(actual, _PLAIN_PY_TYPES):
        actual = onp.asarray(actual)

    if isinstance(expected, _PLAIN_PY_TYPES):
        expected = onp.asarray(expected)

    if error > 0:
        onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
                                        decimal=error)
    else:
        onp.testing.assert_equal(actual.tolist(), expected.tolist())


def check_all_results(onp_results, mnp_results, error=0):
    """
    Check all results from numpy and mindspore.numpy.

    Args:
        onp_results: sequence of numpy results.
        mnp_results: sequence of Tensors, compared index by index with
            `onp_results`.
        error (int): decimal precision passed to `match_array`; 0 means exact.
    """
    for i, onp_res in enumerate(onp_results):
        # Indexing (rather than zip) keeps an IndexError when mnp_results is
        # shorter, instead of silently skipping comparisons.
        match_array(onp_res, mnp_results[i].asnumpy(), error)


def check_all_unique_results(onp_results, mnp_results, error=7):
    """
    Check all results from numpy and mindspore.numpy.

    Args:
        onp_results (Union[tuple of numpy.arrays, numpy.array])
        mnp_results (Union[tuple of Tensors, Tensor])
        error (int): decimal precision passed to `match_array`; defaults to 7.
    """
    for i, onp_res in enumerate(onp_results):
        mnp_res = mnp_results[i]
        if isinstance(onp_res, tuple):
            # Tuple entries are compared element by element.
            for j, onp_item in enumerate(onp_res):
                match_array(onp_item, mnp_res[j].asnumpy(), error=error)
        else:
            match_array(onp_res, mnp_res.asnumpy(), error=error)


def run_non_kw_test(mnp_fn, onp_fn, test_case):
    """
    Run tests on functions with non keyword arguments.

    For each of ``arrs``, ``scalars``, ``expanded_arrs`` and ``nested_arrs``
    on `test_case`, both functions are called with every prefix ``seq[:i]``
    for ``i`` in ``range(len(seq))`` and their results compared.
    """
    for attr in ('arrs', 'scalars', 'expanded_arrs', 'nested_arrs'):
        cases = getattr(test_case, attr)
        # NOTE(review): this yields prefixes from the empty one up to
        # ``cases[:-1]``; the full sequence is never passed. Confirm whether
        # ``range(len(cases) + 1)`` was intended.
        for i in range(len(cases)):
            match_res(mnp_fn, onp_fn, *cases[:i])


def rand_int(*shape):
    """
    Return a random integer-valued float32 array with the given shape.

    Values are drawn from [1, 5). If numpy yields a scalar instead of an
    array, it is returned as a Python float.
    """
    res = onp.random.randint(low=1, high=5, size=shape)
    if isinstance(res, onp.ndarray):
        return res.astype(onp.float32)
    return float(res)


def rand_bool(*shape):
    """Return a random boolean array with the given shape."""
    return onp.random.rand(*shape) > 0.5


def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
    """
    Check the results of applying mnp_fn and onp_fn to arrs respectively.

    Options consumed here (not forwarded to the functions):
        dtype: dtype of the Tensor inputs built from `arrs`; default
            ``mnp.float32``.
        error (int): decimal precision of the comparison; default 0 (exact).
    All other keyword arguments are forwarded to both functions.
    """
    dtype = kwargs.pop('dtype', mnp.float32)
    error = kwargs.pop('error', 0)
    make_tensor = functools.partial(Tensor, dtype=dtype)
    mnp_arrs = [make_tensor(arr) for arr in arrs]
    mnp_res = mnp_fn(*mnp_arrs, **kwargs)
    onp_res = onp_fn(*arrs, **kwargs)
    match_all_arrays(mnp_res, onp_res, error=error)


def match_all_arrays(mnp_res, onp_res, error=0):
    """
    Compare mindspore.numpy result(s) with numpy result(s).

    If `mnp_res` is a tuple or list, it must have the same length as
    `onp_res`, and each Tensor is compared with its numpy counterpart.
    Otherwise `mnp_res` is treated as a single Tensor.
    """
    if isinstance(mnp_res, (tuple, list)):
        assert len(mnp_res) == len(onp_res)
        for actual, expected in zip(mnp_res, onp_res):
            match_array(actual.asnumpy(), expected, error)
    else:
        match_array(mnp_res.asnumpy(), onp_res, error)


def match_meta(actual, expected):
    """
    Assert that `actual` has the same shape and dtype as `expected`.

    The dtype of `expected` is first mapped from float64/int64 to
    float32/int32.
    """
    # float64 and int64 are not supported, and the default type for
    # float and int are float32 and int32, respectively
    if expected.dtype == onp.float64:
        expected = expected.astype(onp.float32)
    elif expected.dtype == onp.int64:
        expected = expected.astype(onp.int32)
    assert actual.shape == expected.shape
    assert actual.dtype == expected.dtype


def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
    """
    Compare a binary function on argument pairs drawn from `test_case`.

    The pairs are:
    - each array with itself;
    - the last array with every scalar, in both argument orders;
    - every pair of scalars;
    - every pair of expanded arrays;
    - every pair of broadcastable arrays.
    """
    for arr in test_case.arrs:
        match_res(mnp_fn, onp_fn, arr, arr, error=error)

    # Only the last array is paired with the scalars. With no arrays there is
    # nothing to pair, so this step is skipped.
    # NOTE(review): confirm whether every array should be paired instead.
    if test_case.arrs:
        last_arr = test_case.arrs[-1]
        for scalar in test_case.scalars:
            match_res(mnp_fn, onp_fn, last_arr, scalar, error=error)
            match_res(mnp_fn, onp_fn, scalar, last_arr, error=error)

    for scalar1 in test_case.scalars:
        for scalar2 in test_case.scalars:
            match_res(mnp_fn, onp_fn, scalar1, scalar2, error=error)

    for expanded_arr1 in test_case.expanded_arrs:
        for expanded_arr2 in test_case.expanded_arrs:
            match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2, error=error)

    for broadcastable1 in test_case.broadcastables:
        for broadcastable2 in test_case.broadcastables:
            match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2, error=error)


def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
    """Compare a unary function on every array, scalar and expanded array."""
    for arr in test_case.arrs:
        match_res(mnp_fn, onp_fn, arr, error=error)

    for arr in test_case.scalars:
        match_res(mnp_fn, onp_fn, arr, error=error)

    for arr in test_case.expanded_arrs:
        match_res(mnp_fn, onp_fn, arr, error=error)


def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
    """
    Call both functions on all of `arrs` and compare their outputs pairwise.

    Each output pair is compared with `match_all_arrays`.
    """
    mnp_arrs = map(Tensor, arrs)
    for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
        match_all_arrays(actual, expected, error)


def run_single_test(mnp_fn, onp_fn, arr, error=0):
    """
    Call both functions on one array and compare their outputs pairwise.

    When a numpy output is a tuple, the matching mindspore output is compared
    element by element; otherwise it is compared as a single Tensor.
    """
    mnp_arr = Tensor(arr)
    for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
        if isinstance(expected, tuple):
            for actual_arr, expected_arr in zip(actual, expected):
                match_array(actual_arr.asnumpy(), expected_arr, error)
        else:
            # Only a single Tensor has ``asnumpy``; a tuple output is fully
            # handled by the branch above.
            match_array(actual.asnumpy(), expected, error)


def run_logical_test(mnp_fn, onp_fn, test_case):
    """Compare a binary logical function on every pair of boolean arrays."""
    for x1 in test_case.boolean_arrs:
        for x2 in test_case.boolean_arrs:
            match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)


def to_tensor(obj, dtype=None):
    """
    Convert `obj` to a Tensor.

    If `dtype` is None, a float64 result is cast to float32 and an int64
    result to int32. Otherwise the Tensor is built with the given `dtype`.
    """
    if dtype is None:
        res = Tensor(obj)
        if res.dtype == mnp.float64:
            res = res.astype(mnp.float32)
        elif res.dtype == mnp.int64:
            res = res.astype(mnp.int32)
    else:
        res = Tensor(obj, dtype)
    return res