# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utility functions for mindspore.numpy st tests"""
import functools
import numpy as onp
from mindspore import Tensor
import mindspore.numpy as mnp


def match_array(actual, expected, error=0):
    """Assert that ``actual`` and ``expected`` are element-wise almost equal.

    Args:
        actual: value under test; a numpy array/scalar, or a plain Python
            int, float, tuple or list (plain values are wrapped with
            ``numpy.asarray`` first).
        expected: reference value, accepted in the same forms as ``actual``.
        error (int): forwarded as ``decimal`` to
            ``numpy.testing.assert_almost_equal``.
    """
    # Plain Python values have no ``tolist``; wrap them so both sides are
    # compared through the same nested-list representation.
    if isinstance(actual, (int, float, tuple, list)):
        actual = onp.asarray(actual)

    if isinstance(expected, (int, float, tuple, list)):
        expected = onp.asarray(expected)

    onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
                                    decimal=error)


def check_all_results(onp_results, mnp_results, error=0):
    """Check all results from numpy and mindspore.numpy.

    Args:
        onp_results: sequence of reference numpy results.
        mnp_results: indexable sequence of Tensors; ``mnp_results[i]`` is
            compared with ``onp_results[i]``.
        error (int): decimal precision forwarded to ``match_array``.
    """
    for i, onp_res in enumerate(onp_results):
        # ``error`` must be forwarded, otherwise callers asking for a
        # tighter tolerance are silently checked at decimal=0.
        match_array(onp_res, mnp_results[i].asnumpy(), error)


def check_all_unique_results(onp_results, mnp_results):
    """
    Check all results from numpy and mindspore.numpy to 7 decimal places.

    Args:
        onp_results (Union[tuple of numpy.arrays, numpy.array])
        mnp_results (Union[tuple of Tensors, Tensor])
    """
    for i, onp_res in enumerate(onp_results):
        if isinstance(onp_res, tuple):
            # Nested tuple results are compared element by element.
            for j, onp_item in enumerate(onp_res):
                match_array(onp_item, mnp_results[i][j].asnumpy(), error=7)
        else:
            match_array(onp_res, mnp_results[i].asnumpy(), error=7)


def run_non_kw_test(mnp_fn, onp_fn, test_case):
    """Run tests on functions with non keyword arguments.

    For each of ``test_case.arrs``, ``scalars``, ``expanded_arrs`` and
    ``nested_arrs``, calls both functions with every prefix ``group[:i]``
    for ``i`` in ``range(len(group))`` and compares the results.

    NOTE(review): this starts with the empty prefix and never passes the
    whole group; confirm that is intended before widening it.
    """
    for attr in ('arrs', 'scalars', 'expanded_arrs', 'nested_arrs'):
        group = getattr(test_case, attr)
        for i in range(len(group)):
            match_res(mnp_fn, onp_fn, *group[:i])


def rand_int(*shape):
    """Return a random integer-valued float32 array with the given shape.

    Values are drawn from [1, 5). With no shape, a Python float is
    returned instead of an array.
    """
    res = onp.random.randint(low=1, high=5, size=shape)
    if isinstance(res, onp.ndarray):
        return res.astype(onp.float32)
    return float(res)


def rand_bool(*shape):
    """Return a random boolean array with the given shape (each True w.p. 0.5)."""
    return onp.random.rand(*shape) > 0.5


def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
    """Checks results from applying mnp_fn and onp_fn on arrs respectively.

    Args:
        mnp_fn: mindspore.numpy function, called on ``arrs`` wrapped in Tensors.
        onp_fn: numpy reference function, called on the raw ``arrs``.
        *arrs: positional inputs shared by both functions.
        **kwargs: ``dtype`` (default ``mnp.float32``) sets the Tensor dtype
            of the inputs and ``error`` (default 0) the decimal precision;
            both are consumed here. Any other keyword arguments are passed
            through to both functions.
    """
    dtype = kwargs.pop('dtype', mnp.float32)
    mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
    error = kwargs.pop('error', 0)
    mnp_res = mnp_fn(*mnp_arrs, **kwargs)
    onp_res = onp_fn(*arrs, **kwargs)
    match_all_arrays(mnp_res, onp_res, error=error)


def match_all_arrays(mnp_res, onp_res, error=0):
    """Compare a Tensor (or a tuple/list of Tensors) with numpy results.

    Sequences must have equal length and are compared pairwise.
    """
    if isinstance(mnp_res, (tuple, list)):
        assert len(mnp_res) == len(onp_res)
        for actual, expected in zip(mnp_res, onp_res):
            match_array(actual.asnumpy(), expected, error)
    else:
        match_array(mnp_res.asnumpy(), onp_res, error)


def match_meta(actual, expected):
    """Assert that ``actual`` has the same shape and dtype as ``expected``.

    ``expected`` is a numpy array; ``actual`` is assumed to expose
    numpy-compatible ``shape``/``dtype`` attributes (TODO confirm).
    """
    # float64 and int64 are not supported, and the default type for
    # float and int are float32 and int32, respectively
    if expected.dtype == onp.float64:
        expected = expected.astype(onp.float32)
    elif expected.dtype == onp.int64:
        expected = expected.astype(onp.int32)
    assert actual.shape == expected.shape
    assert actual.dtype == expected.dtype


def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
    """Compare a binary mindspore.numpy function with its numpy reference.

    Covers array/array (same array twice), array/scalar in both orders,
    scalar/scalar, expanded_arrs pairs and broadcastables pairs.
    """
    arrs = list(test_case.arrs)
    for arr in arrs:
        match_res(mnp_fn, onp_fn, arr, arr, error=error)

    # Array/scalar mixes use the last array, which is what the previous
    # implicit reuse of the leaked loop variable did. With no arrays this
    # section is skipped instead of raising UnboundLocalError.
    if arrs:
        last_arr = arrs[-1]
        for scalar in test_case.scalars:
            match_res(mnp_fn, onp_fn, last_arr, scalar, error=error)
            match_res(mnp_fn, onp_fn, scalar, last_arr, error=error)

    for scalar1 in test_case.scalars:
        for scalar2 in test_case.scalars:
            match_res(mnp_fn, onp_fn, scalar1, scalar2, error=error)

    for expanded_arr1 in test_case.expanded_arrs:
        for expanded_arr2 in test_case.expanded_arrs:
            match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2, error=error)

    for broadcastable1 in test_case.broadcastables:
        for broadcastable2 in test_case.broadcastables:
            match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2, error=error)


def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
    """Compare a unary function on arrs, scalars and expanded_arrs."""
    for arr in test_case.arrs:
        match_res(mnp_fn, onp_fn, arr, error=error)

    for arr in test_case.scalars:
        match_res(mnp_fn, onp_fn, arr, error=error)

    for arr in test_case.expanded_arrs:
        match_res(mnp_fn, onp_fn, arr, error=error)


def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
    """Compare functions that return several outputs, output by output.

    Inputs are wrapped with ``Tensor`` using its default dtype.
    """
    mnp_arrs = map(Tensor, arrs)
    for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
        match_all_arrays(actual, expected, error)


def run_single_test(mnp_fn, onp_fn, arr, error=0):
    """Compare a single-input function returning several outputs.

    An output whose numpy reference is a tuple is compared element by
    element; other outputs are compared directly.
    """
    mnp_arr = Tensor(arr)
    for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
        if isinstance(expected, tuple):
            for actual_arr, expected_arr in zip(actual, expected):
                match_array(actual_arr.asnumpy(), expected_arr, error)
        else:
            # Only non-tuple outputs are single Tensors with ``asnumpy``;
            # without this ``else`` tuple outputs always raised here.
            match_array(actual.asnumpy(), expected, error)


def run_logical_test(mnp_fn, onp_fn, test_case):
    """Compare a binary logical function on all pairs of boolean_arrs."""
    for x1 in test_case.boolean_arrs:
        for x2 in test_case.boolean_arrs:
            match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)


def to_tensor(obj, dtype=None):
    """Wrap ``obj`` in a Tensor.

    With ``dtype=None``, a resulting float64 Tensor is cast to float32 and
    an int64 Tensor to int32; otherwise ``dtype`` is used as given.
    """
    if dtype is None:
        res = Tensor(obj)
        if res.dtype == mnp.float64:
            res = res.astype(mnp.float32)
        elif res.dtype == mnp.int64:
            res = res.astype(mnp.int32)
    else:
        res = Tensor(obj, dtype)
    return res