• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test sparsify"""
16import platform
17import numpy as np
18import pytest
19import scipy
20import scipy.sparse.linalg
21from scipy.linalg import eigvals
22
23from mindspore import Tensor, CSRTensor, context, ops
24from mindspore import dtype as mstype
25from mindspore.nn import Cell
26from mindspore.rewrite import sparsify, ArgType
27
28
def to_tensor(obj, tensor_type):
    """Wrap `obj` in the MindSpore container named by `tensor_type`.

    Args:
        obj: Array-like data to convert.
        tensor_type (str): "Tensor" for a dense Tensor, "CSRTensor" for a CSR sparse tensor.

    Returns:
        A Tensor, a CSRTensor, or `obj` itself when `tensor_type` is neither of the two names.
    """
    if tensor_type == "Tensor":
        dense = np.array(obj)
        return Tensor(dense)
    if tensor_type == "CSRTensor":
        # Let scipy build the CSR layout; index arrays are cast to int32 before wrapping.
        csr = scipy.sparse.csr_matrix(obj)
        row_ptr = Tensor(csr.indptr.astype(np.int32))
        col_idx = Tensor(csr.indices.astype(np.int32))
        data = Tensor(csr.data)
        return CSRTensor(indptr=row_ptr, indices=col_idx, values=data, shape=csr.shape)
    # Unknown type name: hand the object back unchanged.
    return obj
38
39
def create_sym_pos_matrix(shape, dtype):
    """Create a random symmetric positive definite matrix.

    A random matrix ``x`` is drawn with ``np.random.random`` and ``x @ x.T + I`` is formed.
    The candidate is accepted once all of its eigenvalues are strictly positive; up to
    100 candidates are tried.

    Args:
        shape (tuple): Shape of the matrix; must describe a square 2-D matrix.
        dtype: NumPy dtype of the random draw and of the returned matrix.

    Returns:
        numpy.ndarray, matrix with the requested shape and dtype.

    Raises:
        ValueError: If `shape` is not a square 2-D shape, or if no positive definite
            matrix was produced within 100 attempts.
    """
    if len(shape) != 2 or shape[0] != shape[1]:
        # Format the shape into the message; passing it as a second positional argument
        # would make the exception text render as a tuple instead of a sentence.
        raise ValueError(
            f'Symmetric positive definite matrix must be a square matrix, but has shape: {shape}')
    n = shape[-1]
    for _ in range(100):
        x = np.random.random(shape).astype(dtype)
        a = (np.matmul(x, x.T) + np.eye(n)).astype(dtype)
        # eigvals may return a complex-typed array even for a symmetric input, so compare
        # the real parts explicitly instead of relying on ordering of complex values.
        if np.min(eigvals(a).real) > 0:
            return a
    raise ValueError('Symmetric positive definite matrix create failed')
53
54
# Cell returning sqrt(reduce_sum(x ** 2)) of its input.
# NOTE(review): reduce_sum is called without an axis argument; presumably it reduces over
# every axis, which makes this the Euclidean norm of all elements - confirm against the
# ops.reduce_sum documentation.
class Norm(Cell):
    def __init__(self):
        # pylint: disable=useless-super-delegation
        super(Norm, self).__init__()

    def construct(self, x):
        # Square element-wise, sum, then take the square root.
        return ops.sqrt(ops.reduce_sum(x ** 2))
62
63
# Cell computing a product of `a` and `b` through a single ops.matmul call:
# `b` is flattened to two dimensions, multiplied by `a`, and the result is reshaped to
# a.shape[:-1] + b.shape[1:].
class Dot(Cell):
    def __init__(self):
        # pylint: disable=useless-super-delegation
        super(Dot, self).__init__()

    def construct(self, a, b):
        # Keep b's leading axis and collapse all remaining axes into one.
        b_aligned = ops.reshape(b, (b.shape[0], -1))
        res = ops.matmul(a, b_aligned)
        # Restore the trailing axes of b that were collapsed above.
        res = ops.reshape(res, a.shape[:-1] + b.shape[1:])
        return res
74
75
# Preconditioned conjugate gradient iteration written as a Cell.
# construct(a, b, x0, m, maxiter, tol, atol) iterates from the initial guess x0 on the
# system a * x = b, using m as the preconditioner, and returns (x, info) where info is
# the iteration count k if the residual norm is still above the tolerance, else 0.
# test_cg compares this pair against the return value of scipy.sparse.linalg.cg.
class CG(Cell):
    def __init__(self):
        super(CG, self).__init__()
        self.norm = Norm()
        self.dot = Dot()

    def construct(self, a, b, x0, m, maxiter, tol, atol):
        # Effective absolute tolerance: the larger of atol and tol scaled by norm(b).
        atol = ops.maximum(atol, tol * self.norm(b))

        # Initial residual, preconditioned residual (also the first search direction)
        # and their inner product.
        r = b - self.dot(a, x0)
        z = p = self.dot(m, r)
        rho = self.dot(r, z)
        # Iteration counter kept as an int32 Tensor.
        k = Tensor(0, mstype.int32)
        x = x0
        # Iterate until maxiter is reached or the residual norm drops to atol or below.
        while k < maxiter and self.norm(r) > atol:
            q = self.dot(a, p)
            # Step length along the current search direction p.
            alpha = rho / self.dot(p, q)
            x = x + alpha * p
            r = r - alpha * q

            # Re-apply the preconditioner and build the next search direction.
            z = self.dot(m, r)
            rho_ = self.dot(r, z)
            beta = rho_ / rho
            p = z + beta * p
            rho = rho_
            k += 1

        # cond is True when the residual is still above the tolerance (not converged).
        cond = self.norm(r) > atol
        # Second output: k when not converged, otherwise an int32 zero.
        return x, ops.select(cond, k, ops.zeros_like(cond).astype(mstype.int32))
105
106
def to_np(x):
    """Convert a MindSpore tensor into its NumPy/SciPy counterpart.

    Args:
        x: A CSRTensor, or any object exposing `asnumpy()`.

    Returns:
        scipy.sparse.csr_matrix for a CSRTensor input, otherwise the result of `x.asnumpy()`.
    """
    if not isinstance(x, CSRTensor):
        return x.asnumpy()
    # Rebuild the scipy CSR matrix from the (data, indices, indptr) triple.
    data = x.values.asnumpy()
    col_idx = x.indices.asnumpy()
    row_ptr = x.indptr.asnumpy()
    return scipy.sparse.csr_matrix((data, col_idx, row_ptr), shape=x.shape)
111
112
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("mode", [context.PYNATIVE_MODE, context.GRAPH_MODE])
@pytest.mark.parametrize("tensor_type_a", ["Tensor", "CSRTensor"])
@pytest.mark.parametrize("tensor_type_m", ["Tensor", "CSRTensor"])
def test_cg(mode, tensor_type_a, tensor_type_m):
    """
    Feature: Sparsify scipy.cg
    Description: test case for sparsify using CG network.
    Expectation: the result matches scipy.sparse.linalg.cg
    """
    if platform.system() != "Linux":
        return
    context.set_context(mode=mode)
    # Seed before the first random draw so that both the system matrix `a` and the
    # right-hand side `b` are reproducible from run to run.
    np.random.seed(0)
    shape = (7, 7)
    dtype = np.float32
    maxiter = 3
    tol = 1e-5
    a = to_tensor(create_sym_pos_matrix(shape, dtype), tensor_type_a)
    b = Tensor(np.random.random(shape[:1]).astype(dtype))
    x0 = ops.zeros_like(b)
    # Identity preconditioner, dense or CSR depending on the parametrization.
    m = to_tensor(np.eye(shape[0], dtype=dtype), tensor_type_m)
    # Reference result computed by SciPy on the NumPy/SciPy views of the same inputs.
    # NOTE(review): SciPy 1.12 deprecated the `tol` keyword of cg in favour of `rtol`;
    # confirm the pinned SciPy version before upgrading.
    sp_res = scipy.sparse.linalg.cg(to_np(a), to_np(b), to_np(x0), M=to_np(m), maxiter=maxiter, atol=tol, tol=tol)

    func = CG()
    # Only the arguments passed as CSRTensor are declared sparse to sparsify.
    arg_types = {}
    if tensor_type_a == "CSRTensor":
        arg_types["a"] = ArgType.CSR
    if tensor_type_m == "CSRTensor":
        arg_types["m"] = ArgType.CSR
    sparse_func = sparsify(func, arg_types)
    sparsify_res = sparse_func(a, b, x0, m, maxiter, tol, tol)

    assert len(sp_res) == len(sparsify_res)
    for expect, actual in zip(sp_res, sparsify_res):
        assert np.allclose(expect, actual.asnumpy(), rtol=1e-3, atol=1e-5)
152