# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test msssim

Compile-time tests for ``mindspore.nn.MSSSIM``: graph compilation on valid
image pairs, and the exceptions expected for invalid constructor arguments
and invalid image pairs.
"""
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor

_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

# Three equal power factors, shared by every test that builds a net.
_EQUAL_FACTORS = (0.033, 0.033, 0.033)
# Default 4-D image shape used across the tests.
_IMG_SHAPE = (8, 3, 128, 128)


class MSSSIMNet(nn.Cell):
    """Cell wrapper that forwards an image pair to an ``nn.MSSSIM`` instance.

    All constructor arguments are handed positionally and unchanged to
    ``nn.MSSSIM``, so any argument validation happens inside that class.
    """

    def __init__(self, max_val=1.0, power_factors=_MSSSIM_WEIGHTS, filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super().__init__()
        self.net = nn.MSSSIM(max_val, power_factors, filter_size, filter_sigma, k1, k2)

    def construct(self, img1, img2):
        return self.net(img1, img2)


def _float_image(shape):
    """Return a Tensor built from uniform [0, 1) float64 samples of ``shape``."""
    return Tensor(np.random.random(shape))


def test_compile():
    """The net (three equal power factors) compiles for two random float images."""
    msssim = MSSSIMNet(power_factors=_EQUAL_FACTORS)
    _cell_graph_executor.compile(msssim, _float_image(_IMG_SHAPE), _float_image(_IMG_SHAPE))


def test_compile_grayscale():
    """The net compiles for uint8 images with values in [0, 255] when ``max_val`` is 255."""
    # NOTE(review): despite the test name, the second axis of these images is 3
    # (channels, assuming NCHW); confirm whether a single-channel input was intended.
    msssim = MSSSIMNet(max_val=255, power_factors=_EQUAL_FACTORS)
    img1, img2 = (Tensor(np.random.randint(0, 256, _IMG_SHAPE, np.uint8)) for _ in range(2))
    _cell_graph_executor.compile(msssim, img1, img2)


def test_msssim_max_val_negative():
    """A negative ``max_val`` is expected to raise ValueError."""
    with pytest.raises(ValueError):
        MSSSIMNet(-1)


def test_msssim_max_val_bool():
    """``True`` as ``max_val`` is expected to raise TypeError.

    Python's bool is a subclass of int, so this checks that booleans are
    rejected rather than silently treated as the integer 1.
    """
    with pytest.raises(TypeError):
        MSSSIMNet(True)


def test_msssim_max_val_zero():
    """A zero ``max_val`` is expected to raise ValueError."""
    with pytest.raises(ValueError):
        MSSSIMNet(0)


def test_msssim_power_factors_set():
    """A set for ``power_factors`` is expected to raise TypeError.

    The set literal collapses to a single element; what matters here is the
    container type, not its contents.
    """
    with pytest.raises(TypeError):
        MSSSIMNet(power_factors={0.033, 0.033, 0.033})


def test_msssim_filter_size_float():
    """A non-integer ``filter_size`` is expected to raise TypeError."""
    with pytest.raises(TypeError):
        MSSSIMNet(filter_size=1.1)


def test_msssim_filter_size_zero():
    """A zero ``filter_size`` is expected to raise ValueError."""
    with pytest.raises(ValueError):
        MSSSIMNet(filter_size=0)


def test_msssim_filter_sigma_zero():
    """A zero ``filter_sigma`` is expected to raise ValueError."""
    with pytest.raises(ValueError):
        MSSSIMNet(filter_sigma=0.0)


def test_msssim_filter_sigma_negative():
    """A negative ``filter_sigma`` is expected to raise ValueError."""
    with pytest.raises(ValueError):
        MSSSIMNet(filter_sigma=-0.1)


def test_msssim_different_shape():
    """Compiling with two images whose last two axes differ raises ValueError."""
    small = _float_image((8, 3, 128, 128))
    large = _float_image((8, 3, 256, 256))
    msssim = MSSSIMNet(power_factors=_EQUAL_FACTORS)
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(msssim, small, large)


def test_msssim_different_dtype():
    """Compiling with a float32 image and a float16 image raises TypeError."""
    img_fp32 = Tensor(np.random.random(_IMG_SHAPE), dtype=mstype.float32)
    img_fp16 = Tensor(np.random.random(_IMG_SHAPE), dtype=mstype.float16)
    msssim = MSSSIMNet(power_factors=_EQUAL_FACTORS)
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(msssim, img_fp32, img_fp16)


def test_msssim_invalid_5d_input():
    """A 5-D tensor in either argument position, or in both, raises ValueError."""
    shape_5d = _IMG_SHAPE + (1,)
    img1 = _float_image(_IMG_SHAPE)
    invalid_img1 = _float_image(shape_5d)
    img2 = _float_image((8, 3, 256, 256))
    invalid_img2 = _float_image(shape_5d)

    msssim = MSSSIMNet(power_factors=_EQUAL_FACTORS)
    # NOTE(review): the first two pairs also mismatch in shape, which
    # test_msssim_different_shape shows already raises ValueError; likely only
    # the last pair isolates the rank check.
    bad_pairs = (
        (invalid_img1, img2),
        (img1, invalid_img2),
        (invalid_img1, invalid_img2),
    )
    for first, second in bad_pairs:
        with pytest.raises(ValueError):
            _cell_graph_executor.compile(msssim, first, second)