# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test sparsify"""
import inspect
import platform

import numpy as np
import pytest
import scipy
import scipy.sparse.linalg
from scipy.linalg import eigvals

from mindspore import Tensor, CSRTensor, context, ops
from mindspore import dtype as mstype
from mindspore.nn import Cell
from mindspore.rewrite import sparsify, ArgType


def to_tensor(obj, tensor_type):
    """Convert a dense array-like into the requested MindSpore container.

    Args:
        obj: dense array-like input.
        tensor_type: "Tensor" for a dense ``Tensor``; "CSRTensor" for a
            ``CSRTensor`` built through ``scipy.sparse.csr_matrix``. Any other
            value returns ``obj`` unchanged.
    """
    if tensor_type == "Tensor":
        return Tensor(np.array(obj))
    if tensor_type == "CSRTensor":
        obj = scipy.sparse.csr_matrix(obj)
        # Index arrays are cast to int32 before being wrapped in Tensors.
        return CSRTensor(indptr=Tensor(obj.indptr.astype(np.int32)),
                         indices=Tensor(obj.indices.astype(np.int32)),
                         values=Tensor(obj.data), shape=obj.shape)
    return obj


def create_sym_pos_matrix(shape, dtype, max_attempts=100):
    """Return a random symmetric positive definite matrix of ``shape`` and ``dtype``.

    Each candidate is ``x @ x.T + I`` for a random ``x`` (drawn from the global
    ``np.random`` state), cast to ``dtype``. A candidate is accepted once all of
    its eigenvalues are positive; up to ``max_attempts`` candidates are tried.

    Raises:
        ValueError: if ``shape`` is not a square 2-D shape, or if no candidate
            passes the eigenvalue check.
    """
    if len(shape) != 2 or shape[0] != shape[1]:
        raise ValueError(
            f'Symmetric positive definite matrix must be a square matrix, but has shape: {shape}')
    n = shape[-1]
    for _ in range(max_attempts):
        x = np.random.random(shape).astype(dtype)
        a = (np.matmul(x, x.T) + np.eye(n)).astype(dtype)
        # eigvals returns complex values; compare the real parts explicitly
        # rather than relying on complex ordering.
        if np.min(eigvals(a).real) > 0:
            return a
    raise ValueError('Symmetric positive definite matrix create failed')


# Euclidean norm over all elements: sqrt(sum(x ** 2)).
# NOTE: code inside Cell classes is parsed by MindSpore graph mode and by
# mindspore.rewrite, so it is intentionally kept free of docstrings.
class Norm(Cell):
    def __init__(self):
        # pylint: disable=useless-super-delegation
        super(Norm, self).__init__()

    def construct(self, x):
        return ops.sqrt(ops.reduce_sum(x ** 2))


# Contracts the last axis of `a` with the first axis of `b`: `b` is flattened to
# 2-D, multiplied, and the result reshaped to a.shape[:-1] + b.shape[1:].
class Dot(Cell):
    def __init__(self):
        # pylint: disable=useless-super-delegation
        super(Dot, self).__init__()

    def construct(self, a, b):
        b_aligned = ops.reshape(b, (b.shape[0], -1))
        res = ops.matmul(a, b_aligned)
        res = ops.reshape(res, a.shape[:-1] + b.shape[1:])
        return res


# Preconditioned conjugate gradient solving a @ x = b, with the preconditioner
# `m` applied as dot(m, r). Iteration stops after `maxiter` steps or once
# norm(r) <= max(atol, tol * norm(b)). Returns (x, info), where info is the
# iteration count if the residual is still above the threshold, else 0.
class CG(Cell):
    def __init__(self):
        super(CG, self).__init__()
        self.norm = Norm()
        self.dot = Dot()

    def construct(self, a, b, x0, m, maxiter, tol, atol):
        atol = ops.maximum(atol, tol * self.norm(b))

        r = b - self.dot(a, x0)
        z = p = self.dot(m, r)
        rho = self.dot(r, z)
        k = Tensor(0, mstype.int32)
        x = x0
        while k < maxiter and self.norm(r) > atol:
            q = self.dot(a, p)
            alpha = rho / self.dot(p, q)
            x = x + alpha * p
            r = r - alpha * q

            z = self.dot(m, r)
            rho_ = self.dot(r, z)
            beta = rho_ / rho
            p = z + beta * p
            rho = rho_
            k += 1

        cond = self.norm(r) > atol
        return x, ops.select(cond, k, ops.zeros_like(cond).astype(mstype.int32))


def to_np(x):
    """Convert a ``CSRTensor`` to a SciPy CSR matrix, or a ``Tensor`` to a NumPy array."""
    if isinstance(x, CSRTensor):
        return scipy.sparse.csr_matrix(
            (x.values.asnumpy(), x.indices.asnumpy(), x.indptr.asnumpy()), shape=x.shape)
    return x.asnumpy()


def _scipy_cg(a, b, x0, m, maxiter, tol):
    """Run the reference ``scipy.sparse.linalg.cg`` with ``tol`` as both relative and absolute tolerance.

    SciPy 1.12 deprecated the relative-tolerance keyword ``tol`` in favour of
    ``rtol``, and SciPy 1.14 removed ``tol``. The keyword actually accepted by
    the installed SciPy is chosen at call time.
    """
    try:
        params = inspect.signature(scipy.sparse.linalg.cg).parameters
    except (TypeError, ValueError):
        params = {}
    rel_tol_kwarg = "rtol" if "rtol" in params else "tol"
    return scipy.sparse.linalg.cg(a, b, x0, M=m, maxiter=maxiter, atol=tol,
                                  **{rel_tol_kwarg: tol})


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [context.PYNATIVE_MODE, context.GRAPH_MODE])
@pytest.mark.parametrize("tensor_type_a", ["Tensor", "CSRTensor"])
@pytest.mark.parametrize("tensor_type_m", ["Tensor", "CSRTensor"])
def test_cg(mode, tensor_type_a, tensor_type_m):
    """
    Feature: Sparsify scipy.cg
    Description: test case for sparsify using CG network.
    Expectation: the result matches scipy.sparse.linalg.cg
    """
    if platform.system() != "Linux":
        # Report the test as skipped rather than silently passing.
        pytest.skip("sparsify CG test only runs on Linux")
    context.set_context(mode=mode)
    shape = (7, 7)
    dtype = np.float32
    maxiter = 3
    tol = 1e-5
    # Seed before ANY random draw so that `a` (not only `b`) is reproducible.
    np.random.seed(0)
    a = to_tensor(create_sym_pos_matrix(shape, dtype), tensor_type_a)
    b = Tensor(np.random.random(shape[:1]).astype(dtype))
    x0 = ops.zeros_like(b)
    m = to_tensor(np.eye(shape[0], dtype=dtype), tensor_type_m)
    sp_res = _scipy_cg(to_np(a), to_np(b), to_np(x0), to_np(m), maxiter, tol)

    func = CG()
    arg_types = {}
    if tensor_type_a == "CSRTensor":
        arg_types["a"] = ArgType.CSR
    if tensor_type_m == "CSRTensor":
        arg_types["m"] = ArgType.CSR
    sparse_func = sparsify(func, arg_types)
    sparsify_res = sparse_func(a, b, x0, m, maxiter, tol, tol)

    # Both return (solution, info); compare element-wise.
    assert len(sp_res) == len(sparsify_res)
    for expect, actual in zip(sp_res, sparsify_res):
        assert np.allclose(expect, actual.asnumpy(), rtol=1e-3, atol=1e-5)