• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""CosineSimilarity."""
16import numpy as np
17from mindspore._checkparam import Validator as validator
18from .metric import Metric, rearrange_inputs
19
20
class CosineSimilarity(Metric):
    """
    Computes representation similarity.

    Args:
        similarity (str): 'dot' or 'cosine'. Default: 'cosine'.
        reduction (str): 'none', 'sum', 'mean' (all along dim -1). Default: 'none'.
        zero_diagonal (bool): If True, the diagonals are set to zero. Default: True.

    Returns:
        With reduction 'none', a square matrix of shape (N, N) holding the similarity
        scores between all pairs of rows of the 2-D input of shape (N, D).
        With reduction 'sum' or 'mean', an array of shape (N,) holding the reduced
        value of each row of that matrix.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import nn
        >>>
        >>> test_data = np.array([[1, 3, 4, 7], [2, 4, 2, 5], [3, 1, 5, 8]])
        >>> metric = nn.CosineSimilarity()
        >>> metric.clear()
        >>> metric.update(test_data)
        >>> square_matrix = metric.eval()
        >>> print(square_matrix)
        [[0.  0.94025615  0.95162452]
         [0.94025615  0.  0.86146098]
         [0.95162452  0.86146098  0.]]
    """
    def __init__(self, similarity='cosine', reduction='none', zero_diagonal=True):
        super().__init__()
        similarity_list = ['dot', 'cosine']
        reduction_list = ['none', 'sum', 'mean']
        similarity = validator.check_value_type("similarity", similarity, [str])
        self.similarity = validator.check_string(similarity, similarity_list, "similarity")
        reduction = validator.check_value_type("reduction", reduction, [str])
        self.reduction = validator.check_string(reduction, reduction_list, "reduction")
        self.zero_diagonal = validator.check_value_type("zero_diagonal", zero_diagonal, [bool])
        self.clear()

    def clear(self):
        """Clears the internal evaluation result."""
        self.sqr_mtx_res = 0
        self._is_update = False

    @rearrange_inputs
    def update(self, inputs):
        """
        Updates the internal evaluation result with `inputs`.

        Each call replaces (it does not accumulate) the previously stored similarity matrix.

        Args:
            inputs: input_data `input1`, a `Tensor` or an array. It must be 2-D, of shape
                (N, D); rows are the vectors being compared.
        """
        input_data = self._convert_data(inputs)

        if self.similarity == 'cosine':
            # L2-normalise every row so the dot product below yields cosine similarity.
            # A row whose norm is zero divides by zero and produces NaN entries.
            row_norms = np.linalg.norm(input_data, ord=2, axis=1)
            input_data = input_data / np.expand_dims(row_norms, 1)

        self.sqr_mtx_res = np.dot(input_data, input_data.transpose(1, 0))
        self._is_update = True

    def eval(self):
        """
        Computes the Cosine_Similarity square matrix.

        The stored matrix is not modified, so `eval` can be called repeatedly and
        always returns the same result for the same `update`.

        Returns:
            A square matrix of shape (N, N), or an array of shape (N,) when the
            reduction is 'sum' or 'mean'.

        Raises:
            RuntimeError: If the update method is not called first, an error will be reported.
        """
        if not self._is_update:
            raise RuntimeError('Call the update method before calling eval.')

        # Work on a copy: fill_diagonal is in-place, and storing the reduced result back
        # would make a second eval() re-reduce it (fill_diagonal fails on the 1-D result).
        result = self.sqr_mtx_res.copy()

        if self.zero_diagonal:
            np.fill_diagonal(result, 0)

        if self.reduction == 'mean':
            result = np.mean(result, axis=-1)
        elif self.reduction == 'sum':
            result = np.sum(result, axis=-1)

        return result
108