• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""utility functions for mindspore.numpy st tests"""
16import functools
17import numpy as onp
18from mindspore import Tensor
19import mindspore.numpy as mnp
20
21
def match_array(actual, expected, error=0):
    """Assert that ``actual`` and ``expected`` hold the same values.

    Args:
        actual: The value under test. Anything exposing ``tolist()`` (for
            example a numpy array or numpy scalar) is used as-is. Other inputs,
            such as Python scalars, lists or tuples, are first converted with
            ``numpy.asarray``.
        expected: The reference value, normalised the same way as ``actual``.
        error (int): When positive, the number of decimal places compared with
            ``numpy.testing.assert_almost_equal``. Otherwise values must match
            exactly (``numpy.testing.assert_equal``).

    Raises:
        AssertionError: If the values differ.
    """
    # Python scalars, lists and tuples have no ``tolist``. Normalise them to
    # numpy arrays so that, for example, a plain float compares instead of
    # raising AttributeError.
    if not hasattr(actual, 'tolist'):
        actual = onp.asarray(actual)
    if not hasattr(expected, 'tolist'):
        expected = onp.asarray(expected)

    if error > 0:
        onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
                                        decimal=error)
    else:
        onp.testing.assert_equal(actual.tolist(), expected.tolist())
35
36
def check_all_results(onp_results, mnp_results, error=0):
    """Check all results from numpy and mindspore.numpy.

    Args:
        onp_results: Sequence of reference results produced by numpy.
        mnp_results: Sequence of mindspore.numpy results, indexed in parallel
            with ``onp_results``. Each element must provide ``asnumpy()``.
        error (int): Decimal precision forwarded to ``match_array``. 0 means
            exact comparison.
    """
    for i, expected in enumerate(onp_results):
        # Forward ``error`` so the caller's tolerance is actually applied.
        # The mindspore result is passed as ``actual``, matching
        # ``match_all_arrays``.
        match_array(mnp_results[i].asnumpy(), expected, error=error)
41
42
def check_all_unique_results(onp_results, mnp_results):
    """
    Compare numpy and mindspore.numpy results to 7 decimal places.

    A numpy entry that is itself a tuple is compared item by item against the
    corresponding entry of ``mnp_results``. Any other entry is compared directly.

    Args:
        onp_results (Union[tuple of numpy.arrays, numpy.array])
        mnp_results (Union[tuple of Tensors, Tensor])
    """
    for idx, expected in enumerate(onp_results):
        if not isinstance(expected, tuple):
            match_array(expected, mnp_results[idx].asnumpy(), error=7)
            continue
        for sub_idx, expected_item in enumerate(expected):
            match_array(expected_item,
                        mnp_results[idx][sub_idx].asnumpy(), error=7)
58
59
def run_non_kw_test(mnp_fn, onp_fn, test_case):
    """Run tests on functions with non keyword arguments.

    For each of ``test_case.arrs``, ``scalars``, ``expanded_arrs`` and
    ``nested_arrs``, ``mnp_fn`` and ``onp_fn`` are called positionally with
    every prefix of that collection. The prefixes run from the empty prefix
    up to and including the whole collection, and the results are compared
    via ``match_res``.
    """
    for group_name in ('arrs', 'scalars', 'expanded_arrs', 'nested_arrs'):
        inputs = getattr(test_case, group_name)
        # Iterate to ``len + 1`` so the last prefix is the full collection.
        # Stopping at ``len`` never passes every input at once.
        for i in range(len(inputs) + 1):
            match_res(mnp_fn, onp_fn, *inputs[:i])
77
78
def rand_int(*shape):
    """Return random integers in [1, 5) with the given shape, as float32.

    If numpy yields a scalar instead of an array, it is returned as a Python
    ``float``.
    """
    values = onp.random.randint(low=1, high=5, size=shape)
    if not isinstance(values, onp.ndarray):
        return float(values)
    return values.astype(onp.float32)
85
86
def rand_bool(*shape):
    """Return a random boolean array with the given shape.

    Each entry is ``True`` when its uniform [0, 1) sample exceeds 0.5.
    """
    uniform_samples = onp.random.rand(*shape)
    return uniform_samples > 0.5
90
91
def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
    """Apply ``mnp_fn`` and ``onp_fn`` to ``arrs`` and compare the results.

    Two optional keyword arguments are consumed here and not forwarded:
    ``dtype`` (default ``mnp.float32``), the dtype used when wrapping each
    input in a ``Tensor`` for ``mnp_fn``, and ``error`` (default 0), the
    decimal precision passed to ``match_all_arrays``. All remaining keyword
    arguments go to both functions unchanged.
    """
    dtype = kwargs.pop('dtype', mnp.float32)
    error = kwargs.pop('error', 0)
    to_ms_tensor = functools.partial(Tensor, dtype=dtype)
    mnp_res = mnp_fn(*[to_ms_tensor(arr) for arr in arrs], **kwargs)
    onp_res = onp_fn(*arrs, **kwargs)
    match_all_arrays(mnp_res, onp_res, error=error)
100
101
def match_all_arrays(mnp_res, onp_res, error=0):
    """Compare mindspore.numpy output(s) against numpy reference(s).

    If ``mnp_res`` is a tuple or list, ``onp_res`` must have the same length
    and the elements are compared pairwise. Otherwise the single result is
    compared. ``error`` is forwarded to ``match_array``.
    """
    if not isinstance(mnp_res, (tuple, list)):
        match_array(mnp_res.asnumpy(), onp_res, error)
        return
    assert len(mnp_res) == len(onp_res)
    for mnp_item, onp_item in zip(mnp_res, onp_res):
        match_array(mnp_item.asnumpy(), onp_item, error)
109
110
def match_meta(actual, expected):
    """Assert that ``actual`` has the same shape and dtype as ``expected``.

    A float64 or int64 ``expected`` is first cast to float32 or int32
    respectively before comparing.
    """
    # float64 and int64 are not supported: the default float and int types
    # are float32 and int32, so narrow the reference before comparing.
    for wide_dtype, narrow_dtype in ((onp.float64, onp.float32),
                                     (onp.int64, onp.int32)):
        if expected.dtype == wide_dtype:
            expected = expected.astype(narrow_dtype)
            break
    assert actual.shape == expected.shape
    assert actual.dtype == expected.dtype
120
121
def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
    """Compare a binary op from mindspore.numpy against numpy.

    Each array in ``test_case.arrs`` is paired with itself and with every
    scalar (in both operand orders). Every ordered pair within
    ``scalars``, ``expanded_arrs`` and ``broadcastables`` is then checked.
    ``error`` is forwarded to ``match_res``.
    """
    def check(lhs, rhs):
        match_res(mnp_fn, onp_fn, lhs, rhs, error=error)

    for arr in test_case.arrs:
        check(arr, arr)
        for scalar in test_case.scalars:
            check(arr, scalar)
            check(scalar, arr)

    for group_name in ('scalars', 'expanded_arrs', 'broadcastables'):
        group = getattr(test_case, group_name)
        for lhs in group:
            for rhs in group:
                check(lhs, rhs)
141
142
def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
    """Compare a unary op from mindspore.numpy against numpy.

    Inputs are taken in order from ``test_case.arrs``, ``test_case.scalars``
    and ``test_case.expanded_arrs``. ``error`` is forwarded to ``match_res``.
    """
    for group_name in ('arrs', 'scalars', 'expanded_arrs'):
        for operand in getattr(test_case, group_name):
            match_res(mnp_fn, onp_fn, operand, error=error)
152
153
def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
    """Compare multi-output functions from mindspore.numpy and numpy.

    ``arrs`` is passed to ``onp_fn`` as-is and to ``mnp_fn`` as ``Tensor``s.
    The two returned iterables are walked in lockstep, and each pair is
    checked with ``match_all_arrays`` using ``error`` as the precision.
    """
    ms_inputs = [Tensor(arr) for arr in arrs]
    mnp_outputs = mnp_fn(*ms_inputs)
    onp_outputs = onp_fn(*arrs)
    for mnp_out, onp_out in zip(mnp_outputs, onp_outputs):
        match_all_arrays(mnp_out, onp_out, error)
158
159
def run_single_test(mnp_fn, onp_fn, arr, error=0):
    """Compare mindspore.numpy and numpy functions on a single input.

    ``arr`` is passed to ``onp_fn`` directly and to ``mnp_fn`` wrapped in a
    ``Tensor``. Both functions return iterables of results, which are
    compared pairwise. When a numpy result is a tuple, the corresponding
    mindspore result is compared element by element. Otherwise it is
    compared as a whole. ``error`` is forwarded to ``match_array``.
    """
    mnp_arr = Tensor(arr)
    for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
        if isinstance(expected, tuple):
            # Tuple results are checked element-wise only: a tuple of Tensors
            # has no ``asnumpy()``, so it must not also be compared as a
            # whole. Check the lengths first, because ``zip`` would silently
            # skip missing or extra outputs.
            assert len(actual) == len(expected)
            for actual_arr, expected_arr in zip(actual, expected):
                match_array(actual_arr.asnumpy(), expected_arr, error)
        else:
            match_array(actual.asnumpy(), expected, error)
167
168
def run_logical_test(mnp_fn, onp_fn, test_case):
    """Compare a boolean binary op on every ordered pair of boolean inputs.

    Pairs are drawn from ``test_case.boolean_arrs``, and the mindspore inputs
    are built with dtype ``mnp.bool_``.
    """
    bool_inputs = test_case.boolean_arrs
    for lhs in bool_inputs:
        for rhs in bool_inputs:
            match_res(mnp_fn, onp_fn, lhs, rhs, dtype=mnp.bool_)
173
174
def to_tensor(obj, dtype=None):
    """Wrap ``obj`` in a mindspore ``Tensor``.

    With an explicit ``dtype``, the tensor is built as ``Tensor(obj, dtype)``.
    Otherwise ``Tensor`` infers the dtype, and an inferred ``float64`` is
    narrowed to ``float32`` and ``int64`` to ``int32``.
    """
    if dtype is not None:
        return Tensor(obj, dtype)

    tensor = Tensor(obj)
    # Narrow inferred 64-bit dtypes to their 32-bit counterparts.
    if tensor.dtype == mnp.float64:
        tensor = tensor.astype(mnp.float32)
    if tensor.dtype == mnp.int64:
        tensor = tensor.astype(mnp.int32)
    return tensor
185