# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cosine_similarity"""
import pytest
import numpy as np
from sklearn.metrics import pairwise
from mindspore.nn.metrics import CosineSimilarity

# Three row vectors; rows 0 and 1 are identical, so their cosine similarity is exactly 1.
# Kept as a plain list so every test builds its own fresh ndarray.
_TEST_DATA = [[5, 8, 3, 2], [5, 8, 3, 2], [4, 2, 3, 4]]


def _sklearn_cosine_similarity(data, similarity, reduction, zero_diagonal=False):
    """Compute a reference similarity result with sklearn.

    Args:
        data: 2-D array whose rows are compared pairwise.
        similarity: 'cosine' (uses pairwise.cosine_similarity) or
            'dot' (uses pairwise.linear_kernel).
        reduction: 'mean' or 'sum' reduce along the last axis; any other
            value returns the full square matrix.
        zero_diagonal: if True, self-similarity entries are set to 0
            before any reduction is applied.

    Returns:
        np.ndarray: the square similarity matrix, or its per-row reduction.
    """
    metric_func = {'cosine': pairwise.cosine_similarity,
                   'dot': pairwise.linear_kernel}[similarity]

    square_matrix = metric_func(data, data)
    if zero_diagonal:
        # Zero the diagonal first so that mean/sum exclude self-similarity.
        np.fill_diagonal(square_matrix, 0)
    if reduction == 'mean':
        return square_matrix.mean(axis=-1)
    if reduction == 'sum':
        return square_matrix.sum(axis=-1)
    return square_matrix


def test_cosine_similarity():
    """Default construction yields the full cosine matrix with a zeroed diagonal."""
    test_data = np.array(_TEST_DATA)
    metric = CosineSimilarity()
    metric.clear()
    metric.update(test_data)
    square_matrix = metric.eval()

    # cos([5,8,3,2], [4,2,3,4]) = 53 / (sqrt(102) * sqrt(45)) ~= 0.78229315
    expected = np.array([[0, 1, 0.78229315],
                         [1, 0, 0.78229315],
                         [0.78229315, 0.78229315, 0]])
    assert np.allclose(square_matrix, expected)


@pytest.mark.parametrize('similarity', ['cosine', 'dot'])
@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
@pytest.mark.parametrize('zero_diagonal', [False, True])
def test_cosine_similarity_compare(similarity, reduction, zero_diagonal):
    """CosineSimilarity matches the sklearn reference for each option combination."""
    test_data = np.array(_TEST_DATA)
    metric = CosineSimilarity(similarity=similarity, reduction=reduction, zero_diagonal=zero_diagonal)
    metric.clear()
    metric.update(test_data)
    ms_result = metric.eval()

    sk_result = _sklearn_cosine_similarity(test_data, similarity=similarity, reduction=reduction,
                                           zero_diagonal=zero_diagonal)

    # np.allclose broadcasts, so check the shape explicitly to catch a missing/extra reduction.
    assert np.shape(ms_result) == np.shape(sk_result)
    assert np.allclose(sk_result, ms_result)


def test_cosine_similarity_init1():
    """An unrecognised `similarity` string is rejected with ValueError."""
    with pytest.raises(ValueError):
        CosineSimilarity(similarity="4")


def test_cosine_similarity_init2():
    """A non-string `similarity` is rejected with TypeError."""
    with pytest.raises(TypeError):
        CosineSimilarity(similarity=4)


def test_cosine_similarity_init3():
    """A non-string `reduction` is rejected with TypeError."""
    with pytest.raises(TypeError):
        CosineSimilarity(reduction=2)


def test_cosine_similarity_init4():
    """An unrecognised `reduction` string is rejected with ValueError."""
    with pytest.raises(ValueError):
        CosineSimilarity(reduction="1")


def test_cosine_similarity_init5():
    """A non-bool `zero_diagonal` is rejected with TypeError."""
    with pytest.raises(TypeError):
        CosineSimilarity(zero_diagonal=3)


def test_cosine_similarity_runtime():
    """Calling eval() after clear() without any update() raises RuntimeError."""
    metric = CosineSimilarity()
    metric.clear()

    with pytest.raises(RuntimeError):
        metric.eval()