• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""common utils for sparse tests"""
16import platform
17from mindspore import Tensor, CSRTensor, COOTensor, context, ops
18import mindspore.common.dtype as mstype
19
20
def get_platform():
    """Return the host operating system name, lower-cased.

    The name comes from ``platform.system()`` and is passed through
    ``str.lower()``, so it can be compared against lower-case literals
    such as ``"linux"`` or ``"windows"``.
    """
    system_name = platform.system()
    return system_name.lower()
23
24
def compare_res(tensor_tup, numpy_tup):
    """Assert that each tensor matches its expected value, pairwise.

    Args:
        tensor_tup: sequence of objects exposing ``asnumpy()`` (e.g. Tensors).
        numpy_tup: sequence of expected values, paired positionally with
            ``tensor_tup``.

    Raises:
        AssertionError: if the two sequences differ in length, or if any
            pair is not equal element-wise.
    """
    assert len(tensor_tup) == len(numpy_tup)
    for actual, expected in zip(tensor_tup, numpy_tup):
        # Element-wise ``==`` yields a boolean array (assuming ``asnumpy()``
        # returns an ndarray); every entry must be True.
        elementwise_equal = actual.asnumpy() == expected
        assert elementwise_equal.all()
29
30
def compare_csr(csr1, csr2):
    """Assert that two CSRTensors hold identical data.

    Both arguments must be ``CSRTensor`` instances. Their ``indptr``,
    ``indices`` and ``values`` components must be equal element-wise, and
    their ``shape`` attributes must compare equal.

    Raises:
        AssertionError: on the first check that fails.
    """
    for csr in (csr1, csr2):
        assert isinstance(csr, CSRTensor)
    for component in ("indptr", "indices", "values"):
        lhs = getattr(csr1, component).asnumpy()
        rhs = getattr(csr2, component).asnumpy()
        assert (lhs == rhs).all()
    assert csr1.shape == csr2.shape
38
39
def compare_coo(coo1, coo2):
    """Assert that two COOTensors hold identical data.

    Both arguments must be ``COOTensor`` instances. Their ``indices`` and
    ``values`` components must be equal element-wise, and their ``shape``
    attributes must compare equal.

    Raises:
        AssertionError: on the first check that fails.
    """
    for coo in (coo1, coo2):
        assert isinstance(coo, COOTensor)
    for component in ("indices", "values"):
        lhs = getattr(coo1, component).asnumpy()
        rhs = getattr(coo2, component).asnumpy()
        assert (lhs == rhs).all()
    assert coo1.shape == coo2.shape
46
47
def get_csr_tensor():
    """Build a small, fixed CSRTensor fixture.

    The dense shape is ``(2, 4)``. Row 0 stores ``1`` at column 0 and row 1
    stores ``2`` at column 1. Index arrays are int32 and values are float32.
    """
    row_offsets = Tensor([0, 1, 2], dtype=mstype.int32)
    col_ids = Tensor([0, 1], dtype=mstype.int32)
    data = Tensor([1, 2], dtype=mstype.float32)
    dense_shape = (2, 4)
    return CSRTensor(row_offsets, col_ids, data, dense_shape)
54
55
def get_csr_components():
    """Return the individual pieces of the ``get_csr_tensor()`` fixture.

    Returns:
        tuple: ``(indptr, indices, values, shape)`` read back from the
        freshly built CSRTensor.
    """
    fixture = get_csr_tensor()
    return fixture.indptr, fixture.indices, fixture.values, fixture.shape
60
61
def get_csr_from_scalar(x):
    """Wrap ``x`` as the single stored entry of a ``(2, 3)`` CSRTensor.

    The entry sits at row 0, column 2, and row 1 is empty. ``x`` is
    reshaped to one element with ``x.reshape(1)``, so it is presumably a
    scalar or one-element Tensor (confirm at call sites).
    """
    row_offsets = Tensor([0, 1, 1], dtype=mstype.int32)
    col_ids = Tensor([2], dtype=mstype.int32)
    single_value = x.reshape(1)
    return CSRTensor(row_offsets, col_ids, single_value, (2, 3))
67
68
def csr_add(csr_tensor, x):
    """Return a new CSRTensor with ``x`` added to every stored value.

    Only ``values`` changes. ``indptr``, ``indices`` and ``shape`` are reused
    as-is, so the sparsity pattern stays the same and implicit zeros are not
    touched. ``x`` is combined with ``values`` using ``+`` (presumably a
    scalar or a broadcastable Tensor).
    """
    row_offsets = csr_tensor.indptr
    col_ids = csr_tensor.indices
    shifted_values = csr_tensor.values + x
    return CSRTensor(row_offsets, col_ids, shifted_values, csr_tensor.shape)
71
72
def forward_grad_net(net, *inputs, mode=context.GRAPH_MODE):
    """Run ``net`` forward, then compute its gradients w.r.t. all inputs.

    Args:
        net: the callable to evaluate (presumably an ``nn.Cell``).
        *inputs: positional arguments passed to ``net``.
        mode: execution mode applied with ``context.set_context`` before
            anything runs. This is a global setting and is not restored
            afterwards.

    Returns:
        tuple: ``(forward_output, gradients)``, where ``gradients`` comes
        from ``ops.GradOperation(get_all=True)`` applied to ``net``.
    """
    context.set_context(mode=mode)
    outputs = net(*inputs)
    grad_fn = ops.GradOperation(get_all=True)(net)
    gradients = grad_fn(*inputs)
    return outputs, gradients
78