• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15# """test_confusion_matrix_metric"""
16import numpy as np
17import pytest
18from mindspore import Tensor
19from mindspore.nn.metrics import ConfusionMatrixMetric
20
21
def test_confusion_matrix_metric():
    """Accumulate two batches with ``calculation_method=False`` and check the TPR.

    Over both batches the labels contain 4 positive entries, 3 of which are
    predicted positive when every channel is counted, hence 3 / 4 = 0.75.
    """
    metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tpr", calculation_method=False)
    metric.clear()

    # (prediction, label) pairs fed to ``update`` in order.
    batches = [
        (np.array([[[0], [1]], [[1], [0]]]), np.array([[[0], [1]], [[0], [1]]])),
        (np.array([[[0], [1]], [[1], [0]]]), np.array([[[0], [1]], [[1], [0]]])),
    ]
    for pred_values, label_values in batches:
        metric.update(Tensor(pred_values), Tensor(label_values))

    tpr = metric.eval()
    assert np.allclose(tpr, np.array([0.75]))
36
37
def test_confusion_matrix_metric_update_len():
    """Calling ``update`` with a single input (no label) is expected to raise ValueError."""
    only_prediction = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
    ppv_metric = ConfusionMatrixMetric(skip_channel=True, metric_name="ppv", calculation_method=True)
    ppv_metric.clear()

    with pytest.raises(ValueError):
        ppv_metric.update(only_prediction)
45
46
def test_confusion_matrix_metric_update_dim():
    """A 1-D first input paired with a 2-D second input is expected to raise ValueError."""
    pred_1d = Tensor(np.array([1, 0]))
    label_2d = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
    tnr_metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tnr", calculation_method=True)
    tnr_metric.clear()

    with pytest.raises(ValueError):
        tnr_metric.update(pred_1d, label_2d)
55
56
def test_confusion_matrix_metric_init_skip_channel():
    """An int passed as ``skip_channel`` is expected to be rejected with TypeError."""
    non_bool_flag = 1
    with pytest.raises(TypeError):
        ConfusionMatrixMetric(skip_channel=non_bool_flag)
60
61
def test_confusion_matrix_metric_init_compute_sample():
    """An int passed as ``calculation_method`` is expected to be rejected with TypeError.

    The test name refers to "compute_sample", but the argument exercised here is
    ``calculation_method``.
    """
    non_bool_method = 1
    with pytest.raises(TypeError):
        ConfusionMatrixMetric(calculation_method=non_bool_method)
65
66
def test_confusion_matrix_metric_init_metric_name_type():
    """A non-str ``metric_name`` is expected to be rejected with TypeError.

    The input tensors are built before entering ``pytest.raises`` so that only
    the metric's own calls can satisfy the expectation; an unrelated TypeError
    from tensor construction would otherwise make the test pass spuriously.
    Construction, ``update`` and ``eval`` all stay inside the block because the
    validation point may be any of them. No assertion on the result follows:
    once the expected exception is raised, later lines in the ``with`` block
    never execute, so such an assertion would be dead code.
    """
    y_pred = Tensor(np.array([[[0], [1]], [[1], [0]]]))
    y = Tensor(np.array([[[0], [1]], [[1], [0]]]))

    with pytest.raises(TypeError):
        metric = ConfusionMatrixMetric(skip_channel=True, metric_name=1, calculation_method=False)
        metric.update(y_pred, y)
        metric.eval()
76
77
def test_confusion_matrix_metric_init_metric_name_str():
    """An unknown ``metric_name`` string is expected to raise NotImplementedError.

    The input tensors are built before entering ``pytest.raises`` so that only
    the metric's own calls can satisfy the expectation. The metric calls stay
    inside the block because the name may only be resolved during ``update`` or
    ``eval`` rather than in the constructor. No assertion on the result follows:
    once the expected exception is raised, later lines in the ``with`` block
    never execute, so such an assertion would be dead code.
    """
    y_pred = Tensor(np.array([[[0], [1]], [[1], [0]]]))
    y = Tensor(np.array([[[0], [1]], [[1], [0]]]))

    with pytest.raises(NotImplementedError):
        metric = ConfusionMatrixMetric(skip_channel=True, metric_name="wwwww", calculation_method=False)
        metric.update(y_pred, y)
        metric.eval()
87
88
def test_confusion_matrix_metric_runtime():
    """Calling ``eval`` with no prior ``update`` is expected to raise RuntimeError."""
    empty_metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tnr", calculation_method=True)
    empty_metric.clear()

    with pytest.raises(RuntimeError):
        empty_metric.eval()
95