• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19from mindspore import ops
20from mindspore import context
21from mindspore import Tensor
22
23
# Module-level setup: select graph execution mode and the CPU device target
# before any test in this file runs.
24context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
25
26
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_pinv(dtype):
    """
    Feature: ops.pinv functional interface.
    Description: pseudo-inverse of three small matrices, run once for float32 and once for float64.
    Expectation: each MindSpore result agrees with np.linalg.pinv within the per-dtype tolerance.
    """
    matrices = (
        # 2x2 with non-zero determinant (3*2 - 8*2 = -10).
        np.array([[3., 8.], [2., 2.]], dtype=dtype),
        # 2x2 whose second row is twice the first, i.e. singular.
        np.array([[2., 3.], [4., 6.]], dtype=dtype),
        # 3x2, i.e. non-square.
        np.array([[0., 1.], [1., 1.], [1., 0.]], dtype=dtype),
    )

    # Used as the relative tolerance of np.allclose; float32 gets the looser bound.
    rtol = 1e-4 if dtype == np.float32 else 1e-5

    # Compute every MindSpore result first, then every NumPy reference.
    actual = [ops.pinv(Tensor(matrix)).asnumpy() for matrix in matrices]
    expected = [np.linalg.pinv(matrix) for matrix in matrices]

    for reference, result in zip(expected, actual):
        assert np.allclose(reference, result, rtol=rtol)
57