"""Helpers for comparing MindSpore results against NumPy references in tests."""
import numpy as onp
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import ops
from mindspore._c_expression import get_code_extra


def get_empty_tensor(dtype=mstype.float32):
    """Return a zero-length 1-D Tensor of the given dtype.

    The tensor is produced by slicing zero elements out of a one-element Tensor.

    Args:
        dtype (:class:`mindspore.dtype`): Element type. Default: ``mstype.float32``.

    Returns:
        Tensor: An empty tensor.
    """
    x = Tensor([1], dtype)
    return ops.slice(x, (0,), (0,))


def _as_comparable(value):
    """Convert ``value`` into something ``numpy.testing`` can compare.

    Python ints, bools, tuples and lists are wrapped with ``onp.asarray``;
    Tensors are converted with ``asnumpy()``; anything else is returned as is.
    """
    if isinstance(value, (int, tuple, list, bool)):
        value = onp.asarray(value)
    if isinstance(value, Tensor):
        value = value.asnumpy()
    return value


def match_array(actual, expected, error=0, err_msg=''):
    """Assert that ``actual`` matches ``expected``.

    Args:
        actual: Value under test (Tensor, ndarray, Python scalar or sequence).
        expected: Reference value; accepts the same kinds as ``actual``.
        error (int): When positive, compare with
            ``numpy.testing.assert_almost_equal`` using ``error`` as the number
            of decimal places; otherwise require exact equality.
        err_msg (str): Message reported on failure.

    Raises:
        AssertionError: If the values do not match.
    """
    actual = _as_comparable(actual)
    expected = _as_comparable(expected)

    if error > 0:
        onp.testing.assert_almost_equal(
            actual, expected, decimal=error, err_msg=err_msg)
    else:
        onp.testing.assert_equal(actual, expected, err_msg=err_msg)


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the share of out-of-tolerance elements is below ``rtol``.

    An element is out of tolerance when
    ``|expected - actual| > atol + |actual| * rtol``. Note that ``rtol`` is
    used twice: as the relative tolerance per element and as the maximum
    allowed fraction of failing elements. On failure, the message lists the
    offending expected values, actual values and absolute errors.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # No elements means nothing can mismatch; also avoids 0 / 0 below.
        return
    error = onp.abs(data_expected - data_me)
    greater = onp.greater(error, atol + onp.abs(data_me) * rtol)
    loss_count = onp.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two arrays agree within ``rtol`` / ``atol``.

    If either input contains NaN, a strict ``numpy.allclose`` assertion is used
    (NaNs compare equal when ``equal_nan`` is true). Otherwise, arrays that are
    not all close are passed to ``_count_unequal_element``, which tolerates a
    small fraction of mismatches. Arrays that are all close must additionally
    have identical shapes, since ``allclose`` alone accepts broadcastable ones.

    Raises:
        AssertionError: If the comparison fails.
    """
    if onp.any(onp.isnan(data_expected)) or onp.any(onp.isnan(data_me)):
        assert onp.allclose(data_expected, data_me, rtol,
                            atol, equal_nan=equal_nan)
    elif not onp.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert onp.array(data_expected).shape == onp.array(data_me).shape


def tensor_to_numpy(data):
    """Convert a Tensor, or a (possibly nested) tuple of Tensors, to NumPy.

    Args:
        data: A Tensor, or a tuple whose items are Tensors or such tuples.

    Returns:
        An ndarray for a Tensor input; otherwise a tuple of converted items
        with the same nesting (an empty tuple for an empty tuple).

    Raises:
        AssertionError: If ``data`` or any nested item has another type.
    """
    if isinstance(data, Tensor):
        return data.asnumpy()
    if isinstance(data, tuple):
        # Single pass over the items instead of recursively re-slicing the tail.
        return tuple(tensor_to_numpy(item) for item in data)
    # Explicit raise (rather than `assert False`) so the check survives `python -O`.
    raise AssertionError('unsupported data type')


# NumPy scalar type -> MindSpore dtype; ``None`` passes through as ``None``.
_NPTYPE_TO_MSTYPE = {
    onp.bool_: mstype.bool_,
    onp.int8: mstype.int8,
    onp.int16: mstype.int16,
    onp.int32: mstype.int32,
    onp.int64: mstype.int64,
    onp.uint8: mstype.uint8,
    onp.float16: mstype.float16,
    onp.float32: mstype.float32,
    onp.float64: mstype.float64,
    onp.complex64: mstype.complex64,
    onp.complex128: mstype.complex128,
    None: None,
}


def nptype_to_mstype(type_):
    """Convert a NumPy scalar type to the corresponding MindSpore dtype.

    Args:
        type_: A NumPy scalar type such as ``numpy.float32``, or ``None``.

    Returns:
        The matching :class:`mindspore.dtype`, or ``None`` if ``type_`` is ``None``.

    Raises:
        KeyError: If ``type_`` is not in the supported mapping.
    """
    return _NPTYPE_TO_MSTYPE[type_]


def is_empty(variable):
    """Return True for ``None``, ``""``, or an empty list/tuple/dict/set."""
    if variable is None:
        return True
    if isinstance(variable, str) and variable == "":
        return True
    if isinstance(variable, (list, tuple, dict, set)) and len(variable) == 0:
        return True
    return False


def assert_executed_by_graph_mode(func):
    """Assert that ``func``'s JIT record reports a graph-callable compilation.

    Reads the record returned by ``get_code_extra(func)`` and requires that it
    exists, that its ``stat`` is ``'GRAPH_CALLABLE'``, that ``break_count_`` is
    zero, and that ``code['phase_']`` is non-empty.

    Raises:
        AssertionError: If any of these conditions does not hold.
    """
    jcr = get_code_extra(func)
    assert jcr is not None
    assert jcr['stat'] == 'GRAPH_CALLABLE'
    assert jcr['break_count_'] == 0
    assert len(jcr['code']['phase_']) > 0