• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test error"""
16import math
17import numpy as np
18import pytest
19
20from mindspore import Tensor
21from mindspore.nn.metrics import MAE, MSE
22
23
def test_MAE():
    """Check the value MAE reports after a single update with four samples."""
    metric = MAE()
    metric.clear()
    y_pred = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    metric.update(y_pred, y)
    # Element-wise absolute differences are 0, 0.05, 0.1 and 0: 0.15 in total
    # across the 4 samples.
    expected = 0.15 / 4
    assert math.isclose(metric.eval(), expected)
32
33
def test_input_MAE():
    """Check that MAE.update raises ValueError when handed three inputs."""
    metric = MAE()
    metric.clear()
    y_pred = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    # The third positional tensor is the surplus input expected to be rejected.
    with pytest.raises(ValueError):
        metric.update(y_pred, y, y_pred)
41
42
def test_zero_MAE():
    """Check that MAE.eval raises RuntimeError when no update was ever made."""
    metric = MAE()
    # No update() call precedes eval(), so there are no samples to average.
    with pytest.raises(RuntimeError):
        metric.eval()
47
48
def test_MSE():
    """Check the value MSE reports after a single update with four samples."""
    metric = MSE()
    metric.clear()
    y_pred = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    y = Tensor(np.array([0.1, 0.25, 0.5, 0.9]))
    metric.update(y_pred, y)
    # Element-wise squared differences are 0, 0.0025, 0.01 and 0: 0.0125 in
    # total across the 4 samples.
    expected = 0.0125 / 4
    assert math.isclose(metric.eval(), expected)
57
58
def test_input_MSE():
    """Check that MSE.update raises ValueError when handed three inputs."""
    metric = MSE()
    metric.clear()
    y_pred = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    y = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    # The third positional tensor is the surplus input expected to be rejected.
    with pytest.raises(ValueError):
        metric.update(y_pred, y, y_pred)
66
67
def test_zero_MSE():
    """Check that MSE.eval raises RuntimeError when no update was ever made."""
    metric = MSE()
    # No update() call precedes eval(), so there are no samples to average.
    with pytest.raises(RuntimeError):
        metric.eval()
72