# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""common utils for sparse tests"""
import platform
from mindspore import Tensor, CSRTensor, COOTensor, context, ops
import mindspore.common.dtype as mstype


def get_platform():
    """Return the lower-cased system name reported by ``platform.system()``."""
    system_name = platform.system()
    return system_name.lower()


def compare_res(tensor_tup, numpy_tup):
    """Assert that two equally long sequences match element by element.

    Each item of ``tensor_tup`` is converted with ``asnumpy()`` and compared
    against the item at the same position in ``numpy_tup``; every element of
    the comparison must be true.
    """
    assert len(tensor_tup) == len(numpy_tup)
    for actual, expected in zip(tensor_tup, numpy_tup):
        assert (actual.asnumpy() == expected).all()


def compare_csr(csr1, csr2):
    """Assert that two CSRTensor objects have equal components and shape.

    Both arguments must be ``CSRTensor`` instances. ``indptr``, ``indices``
    and ``values`` are compared (in that order) through ``asnumpy()``, then
    the shapes are compared.
    """
    assert isinstance(csr1, CSRTensor)
    assert isinstance(csr2, CSRTensor)
    for field in ("indptr", "indices", "values"):
        lhs = getattr(csr1, field).asnumpy()
        rhs = getattr(csr2, field).asnumpy()
        assert (lhs == rhs).all()
    assert csr1.shape == csr2.shape


def compare_coo(coo1, coo2):
    """Assert that two COOTensor objects have equal components and shape.

    Both arguments must be ``COOTensor`` instances. ``indices`` and ``values``
    are compared (in that order) through ``asnumpy()``, then the shapes are
    compared.
    """
    assert isinstance(coo1, COOTensor)
    assert isinstance(coo2, COOTensor)
    for field in ("indices", "values"):
        lhs = getattr(coo1, field).asnumpy()
        rhs = getattr(coo2, field).asnumpy()
        assert (lhs == rhs).all()
    assert coo1.shape == coo2.shape


def get_csr_tensor():
    """Build a fixed sample CSRTensor of shape (2, 4).

    The sample uses int32 ``indptr`` [0, 1, 2], int32 ``indices`` [0, 1] and
    float32 ``values`` [1, 2].
    """
    row_ptr = Tensor([0, 1, 2], dtype=mstype.int32)
    col_idx = Tensor([0, 1], dtype=mstype.int32)
    data = Tensor([1, 2], dtype=mstype.float32)
    return CSRTensor(row_ptr, col_idx, data, (2, 4))


def get_csr_components():
    """Return ``(indptr, indices, values, shape)`` of the sample CSRTensor."""
    sample = get_csr_tensor()
    return (sample.indptr, sample.indices, sample.values, sample.shape)


def get_csr_from_scalar(x):
    """Build a CSRTensor of shape (2, 3) whose only stored value is ``x``.

    ``x`` is reshaped with ``reshape(1)`` and used as the values component;
    ``indptr`` is [0, 1, 1] and ``indices`` is [2], both int32.
    """
    row_ptr = Tensor([0, 1, 1], dtype=mstype.int32)
    col_idx = Tensor([2], dtype=mstype.int32)
    single_value = x.reshape(1)
    return CSRTensor(row_ptr, col_idx, single_value, (2, 3))


def csr_add(csr_tensor, x):
    """Return a new CSRTensor with ``x`` added to the values of ``csr_tensor``.

    ``indptr``, ``indices`` and ``shape`` are taken from ``csr_tensor``
    unchanged.
    """
    shifted_values = csr_tensor.values + x
    return CSRTensor(csr_tensor.indptr, csr_tensor.indices, shifted_values, csr_tensor.shape)


def forward_grad_net(net, *inputs, mode=context.GRAPH_MODE):
    """Run ``net`` forward and compute its gradients for all inputs.

    The execution mode is applied through ``context.set_context`` before the
    network is called, and it is not reset afterwards.

    Returns:
        A ``(forward, grad)`` pair: the result of ``net(*inputs)`` and the
        result of ``ops.GradOperation(get_all=True)(net)(*inputs)``.
    """
    context.set_context(mode=mode)
    forward_out = net(*inputs)
    grad_fn = ops.GradOperation(get_all=True)(net)
    grad_out = grad_fn(*inputs)
    return forward_out, grad_out