• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""test cosine_similarity"""
16import pytest
17import numpy as np
18from sklearn.metrics import pairwise
19from mindspore.nn.metrics import CosineSimilarity
20
21
def test_cosine_similarity():
    """Default-constructed CosineSimilarity returns the pairwise cosine matrix with its diagonal zeroed."""
    samples = np.array([[5, 8, 3, 2],
                        [5, 8, 3, 2],
                        [4, 2, 3, 4]])
    # Rows 0 and 1 are identical, so their similarity is 1.
    # Either of them against row 2 gives 53 / sqrt(102 * 45).
    cross = 0.78229315
    expected = np.array([[0, 1, cross],
                         [1, 0, cross],
                         [cross, cross, 0]])

    cos_metric = CosineSimilarity()
    cos_metric.clear()
    cos_metric.update(samples)
    result = cos_metric.eval()

    assert np.allclose(result, expected)
31
32
def test_cosine_similarity_compare():
    """CosineSimilarity (no reduction, diagonal kept) must match a scikit-learn reference."""
    samples = np.array([[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]])

    ms_metric = CosineSimilarity(similarity='cosine', reduction='none', zero_diagonal=False)
    ms_metric.clear()
    ms_metric.update(samples)
    ms_result = ms_metric.eval()

    def sklearn_cosine_similarity(test_data, similarity, reduction):
        """Reference similarity matrix computed with sklearn's pairwise kernels.

        ``similarity`` chooses 'cosine' or 'dot' (any other key raises KeyError).
        ``reduction`` of 'mean' or 'sum' collapses the last axis; any other
        value returns the full square matrix.
        """
        kernel = {'cosine': pairwise.cosine_similarity,
                  'dot': pairwise.linear_kernel}[similarity]
        full = kernel(test_data, test_data)
        if reduction in ('mean', 'sum'):
            # ndarray.mean / ndarray.sum, selected by name.
            return getattr(full, reduction)(axis=-1)
        return full

    sk_result = sklearn_cosine_similarity(samples, similarity='cosine', reduction='none')

    assert np.allclose(sk_result, ms_result)
56
57
def test_cosine_similarity_init1():
    """An unrecognised similarity string ("4") is expected to raise ValueError."""
    pytest.raises(ValueError, CosineSimilarity, similarity="4")
62
63
def test_cosine_similarity_init2():
    """A non-string similarity argument (int 4) is expected to raise TypeError."""
    pytest.raises(TypeError, CosineSimilarity, similarity=4)
68
69
def test_cosine_similarity_init3():
    """A non-string reduction argument (int 2) is expected to raise TypeError."""
    pytest.raises(TypeError, CosineSimilarity, reduction=2)
74
75
def test_cosine_similarity_init4():
    """An unrecognised reduction string ("1") is expected to raise ValueError."""
    pytest.raises(ValueError, CosineSimilarity, reduction="1")
80
81
82
def test_cosine_similarity_init5():
    """A non-bool zero_diagonal argument (int 3) is expected to raise TypeError."""
    pytest.raises(TypeError, CosineSimilarity, zero_diagonal=3)
87
88
def test_cosine_similarity_runtime():
    """Calling eval() on a cleared metric before any update() is expected to raise RuntimeError."""
    fresh_metric = CosineSimilarity()
    fresh_metric.clear()
    pytest.raises(RuntimeError, fresh_metric.eval)
96