import numpy as onp
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import ops
from mindspore._c_expression import get_code_extra


def get_empty_tensor(dtype=mstype.float32):
    """Return an empty tensor of the given dtype by slicing a 1-element tensor down to length 0."""
    x = Tensor([1], dtype)
    output = ops.slice(x, (0,), (0,))
    return output


def match_array(actual, expected, error=0, err_msg=''):
    """Assert that `actual` equals `expected`; when error > 0, compare to `error` decimal places."""
    if isinstance(actual, (int, tuple, list, bool)):
        actual = onp.asarray(actual)

    if isinstance(actual, Tensor):
        actual = actual.asnumpy()

    if isinstance(expected, (int, tuple, list, bool)):
        expected = onp.asarray(expected)

    if isinstance(expected, Tensor):
        expected = expected.asnumpy()

    if error > 0:
        onp.testing.assert_almost_equal(
            actual, expected, decimal=error, err_msg=err_msg)
    else:
        onp.testing.assert_equal(actual, expected, err_msg=err_msg)


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Find elements where |expected - actual| exceeds atol + rtol * |actual|; fail unless their fraction is below rtol."""
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = onp.abs(data_expected - data_me)
    greater = onp.greater(error, atol + onp.abs(data_me) * rtol)
    loss_count = onp.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two numpy arrays are element-wise close; on mismatch, report the offending elements."""
    if onp.any(onp.isnan(data_expected)) or onp.any(onp.isnan(data_me)):
        assert onp.allclose(data_expected, data_me, rtol,
                            atol, equal_nan=equal_nan)
    elif not onp.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert onp.array(data_expected).shape == onp.array(data_me).shape


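# Illustrative usage of the comparison helpers above (kept as comments so nothing
# runs at import time); the operands are made up for the example:
#
#     out = ops.add(Tensor([1.0], mstype.float32), Tensor([2.0], mstype.float32))
#     match_array(out, onp.array([3.0], dtype=onp.float32), error=4)
#     allclose_nparray(onp.array([3.0], dtype=onp.float32), out.asnumpy(),
#                      rtol=1e-4, atol=1e-4)

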
def tensor_to_numpy(data):
    """Recursively convert a Tensor, or a tuple of Tensors, to numpy arrays."""
    if isinstance(data, Tensor):
        return data.asnumpy()
    if isinstance(data, tuple):
        if len(data) == 1:
            return (tensor_to_numpy(data[0]),)
        return (tensor_to_numpy(data[0]), *tensor_to_numpy(data[1:]))
    assert False, 'unsupported data type'


def nptype_to_mstype(type_):
    """
    Convert a NumPy scalar type to the corresponding MindSpore dtype.

    Args:
        type_ (type): NumPy scalar type, e.g. ``onp.float32``.

    Returns:
        The corresponding :class:`mindspore.dtype`, or None if `type_` is None.
    """

    return {
        onp.bool_: mstype.bool_,
        onp.int8: mstype.int8,
        onp.int16: mstype.int16,
        onp.int32: mstype.int32,
        onp.int64: mstype.int64,
        onp.uint8: mstype.uint8,
        onp.float16: mstype.float16,
        onp.float32: mstype.float32,
        onp.float64: mstype.float64,
        onp.complex64: mstype.complex64,
        onp.complex128: mstype.complex128,
        None: None
    }[type_]


def is_empty(variable):
    """Return True if `variable` is None, an empty string, or an empty container."""
    if variable is None:
        return True
    if isinstance(variable, str) and variable == "":
        return True
    if isinstance(variable, (list, tuple, dict, set)) and len(variable) == 0:
        return True
    return False


def assert_executed_by_graph_mode(func):
    """Assert that `func` was compiled as a whole graph: status GRAPH_CALLABLE, zero graph breaks, non-empty compile phase."""
    jcr = get_code_extra(func)
    assert jcr is not None
    assert jcr['stat'] == 'GRAPH_CALLABLE'
    assert jcr['break_count_'] == 0
    assert len(jcr['code']['phase_']) > 0
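

# Illustrative usage of the helper above (kept as comments so nothing runs at
# import time). The decorator shown is an assumption: the exact entry point for
# enabling bytecode/graph capture differs across MindSpore versions, so adapt it
# to the version under test.
#
#     @jit(mode="PIJit")          # hypothetical decorator/arguments
#     def add_fn(a, b):
#         return a + b
#
#     add_fn(Tensor([1.0], mstype.float32), Tensor([2.0], mstype.float32))
#     assert_executed_by_graph_mode(add_fn)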