• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16test ssim
17"""
18import numpy as np
19import pytest
20
21import mindspore.common.dtype as mstype
22import mindspore.nn as nn
23from mindspore import Tensor
24from mindspore.common.api import _cell_graph_executor
25
26
27class SSIMNet(nn.Cell):
28    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
29        super(SSIMNet, self).__init__()
30        self.net = nn.SSIM(max_val, filter_size, filter_sigma, k1, k2)
31
32    def construct(self, img1, img2):
33        return self.net(img1, img2)
34
35
36def test_compile():
37    net = SSIMNet()
38    img1 = Tensor(np.random.random((8, 3, 16, 16)), mstype.float32)
39    img2 = Tensor(np.random.random((8, 3, 16, 16)), mstype.float32)
40    _cell_graph_executor.compile(net, img1, img2)
41
42
43def test_ssim_max_val_negative():
44    max_val = -1
45    with pytest.raises(ValueError):
46        _ = SSIMNet(max_val)
47
48
49def test_ssim_max_val_bool():
50    max_val = True
51    with pytest.raises(TypeError):
52        _ = SSIMNet(max_val)
53
54
55def test_ssim_max_val_zero():
56    max_val = 0
57    with pytest.raises(ValueError):
58        _ = SSIMNet(max_val)
59
60
61def test_ssim_filter_size_float():
62    with pytest.raises(TypeError):
63        _ = SSIMNet(filter_size=1.1)
64
65
66def test_ssim_filter_size_zero():
67    with pytest.raises(ValueError):
68        _ = SSIMNet(filter_size=0)
69
70
71def test_ssim_filter_sigma_zero():
72    with pytest.raises(ValueError):
73        _ = SSIMNet(filter_sigma=0.0)
74
75
76def test_ssim_filter_sigma_negative():
77    with pytest.raises(ValueError):
78        _ = SSIMNet(filter_sigma=-0.1)
79
80
81def test_ssim_different_shape():
82    shape_1 = (8, 3, 16, 16)
83    shape_2 = (8, 3, 8, 8)
84    img1 = Tensor(np.random.random(shape_1))
85    img2 = Tensor(np.random.random(shape_2))
86    net = SSIMNet()
87    with pytest.raises(TypeError):
88        _cell_graph_executor.compile(net, img1, img2)
89
90
91def test_ssim_different_dtype():
92    dtype_1 = mstype.float32
93    dtype_2 = mstype.float16
94    img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1)
95    img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2)
96    net = SSIMNet()
97    with pytest.raises(TypeError):
98        _cell_graph_executor.compile(net, img1, img2)
99
100
101def test_ssim_invalid_5d_input():
102    shape_1 = (8, 3, 16, 16)
103    shape_2 = (8, 3, 8, 8)
104    invalid_shape = (8, 3, 16, 16, 1)
105    img1 = Tensor(np.random.random(shape_1))
106    invalid_img1 = Tensor(np.random.random(invalid_shape))
107    img2 = Tensor(np.random.random(shape_2))
108    invalid_img2 = Tensor(np.random.random(invalid_shape))
109
110    net = SSIMNet()
111    with pytest.raises(TypeError):
112        _cell_graph_executor.compile(net, invalid_img1, img2)
113    with pytest.raises(TypeError):
114        _cell_graph_executor.compile(net, img1, invalid_img2)
115    with pytest.raises(TypeError):
116        _cell_graph_executor.compile(net, invalid_img1, invalid_img2)
117