# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test error"""
import math
import numpy as np
import pytest

from mindspore import Tensor
from mindspore.nn.metrics import MAE, MSE


def test_MAE():
    """Check that MAE over four samples evaluates to 0.15 / 4."""
    preds = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    labels = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    metric = MAE()
    metric.clear()
    metric.update(preds, labels)
    # Absolute differences are 0, 0.05, 0.1 and 0, i.e. 0.15 in total.
    assert math.isclose(metric.eval(), 0.15 / 4)


def test_input_MAE():
    """Check that MAE.update raises ValueError when given three inputs."""
    preds = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    labels = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    metric = MAE()
    metric.clear()
    with pytest.raises(ValueError):
        metric.update(preds, labels, preds)


def test_zero_MAE():
    """Check that MAE.eval raises RuntimeError before any update call."""
    metric = MAE()
    with pytest.raises(RuntimeError):
        metric.eval()


def test_MSE():
    """Check that MSE over four samples evaluates to 0.0125 / 4."""
    preds = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    labels = Tensor(np.array([0.1, 0.25, 0.5, 0.9]))
    metric = MSE()
    metric.clear()
    metric.update(preds, labels)
    # Squared differences are 0, 0.0025, 0.01 and 0, i.e. 0.0125 in total.
    assert math.isclose(metric.eval(), 0.0125 / 4)


def test_input_MSE():
    """Check that MSE.update raises ValueError when given three inputs."""
    preds = Tensor(np.array([0.1, 0.2, 0.6, 0.9]))
    labels = Tensor(np.array([0.1, 0.25, 0.7, 0.9]))
    metric = MSE()
    metric.clear()
    with pytest.raises(ValueError):
        metric.update(preds, labels, preds)


def test_zero_MSE():
    """Check that MSE.eval raises RuntimeError before any update call."""
    metric = MSE()
    with pytest.raises(RuntimeError):
        metric.eval()