• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
Tests for the nn.PSNR cell: graph compilation and validation of max_val and inputs.
17"""
18import numpy as np
19import pytest
20
21import mindspore.nn as nn
22from mindspore import Tensor
23from mindspore.common import dtype as mstype
24from mindspore.common.api import _cell_graph_executor
25
26
class PSNRNet(nn.Cell):
    """Cell that holds an ``nn.PSNR`` instance and forwards two images to it."""

    def __init__(self, max_val=1.0):
        """Build the wrapped ``nn.PSNR`` with the given ``max_val``."""
        super().__init__()
        self.net = nn.PSNR(max_val)

    def construct(self, img1, img2):
        """Return the result of applying the wrapped PSNR cell to ``img1`` and ``img2``."""
        psnr = self.net(img1, img2)
        return psnr
34
35
def test_compile_psnr():
    """Compile PSNRNet (max_val=1.0) on two random inputs of shape (8, 3, 16, 16)."""
    image_shape = (8, 3, 16, 16)
    psnr_net = PSNRNet(1.0)
    # Two independent random draws, created in order: first image, then second.
    first, second = (Tensor(np.random.random(image_shape)) for _ in range(2))
    _cell_graph_executor.compile(psnr_net, first, second)
42
43
def test_compile_psnr_grayscale():
    """Compile PSNRNet (max_val=255) on two random uint8 inputs of shape (8, 1, 16, 16)."""
    psnr_net = PSNRNet(255)
    image_shape = (8, 1, 16, 16)
    # Values are drawn from [0, 256), i.e. the full uint8 range.
    first = Tensor(np.random.randint(0, 256, size=image_shape, dtype=np.uint8))
    second = Tensor(np.random.randint(0, 256, size=image_shape, dtype=np.uint8))
    _cell_graph_executor.compile(psnr_net, first, second)
50
51
def test_psnr_max_val_negative():
    """Building PSNRNet with max_val=-1 is expected to raise ValueError."""
    negative_max_val = -1
    with pytest.raises(ValueError):
        PSNRNet(negative_max_val)
56
57
def test_psnr_max_val_bool():
    """Building PSNRNet with a bool max_val (True) is expected to raise TypeError."""
    bool_max_val = True
    with pytest.raises(TypeError):
        PSNRNet(bool_max_val)
62
63
def test_psnr_max_val_zero():
    """Building PSNRNet with max_val=0 is expected to raise ValueError."""
    zero_max_val = 0
    with pytest.raises(ValueError):
        PSNRNet(zero_max_val)
68
69
def test_psnr_different_shape():
    """Compiling with inputs of shapes (8, 3, 16, 16) and (8, 3, 8, 8) is expected to raise ValueError."""
    mismatched_shapes = ((8, 3, 16, 16), (8, 3, 8, 8))
    # Tensors are created in the listed order: larger image first, smaller second.
    larger, smaller = (Tensor(np.random.random(shape)) for shape in mismatched_shapes)
    psnr_net = PSNRNet()
    with pytest.raises(ValueError):
        _cell_graph_executor.compile(psnr_net, larger, smaller)
78
79
def test_psnr_different_dtype():
    """Compiling with one float32 and one float16 input is expected to raise TypeError."""
    image_shape = (8, 3, 16, 16)
    img_fp32 = Tensor(np.random.random(image_shape), dtype=mstype.float32)
    img_fp16 = Tensor(np.random.random(image_shape), dtype=mstype.float16)
    psnr_net = PSNRNet()
    with pytest.raises(TypeError):
        _cell_graph_executor.compile(psnr_net, img_fp32, img_fp16)
88
89
def test_psnr_invalid_5d_input():
    """Compiling with a 5-D tensor as either or both inputs is expected to raise ValueError."""
    five_d_shape = (8, 3, 16, 16, 1)
    valid_first = Tensor(np.random.random((8, 3, 16, 16)))
    bad_first = Tensor(np.random.random(five_d_shape))
    valid_second = Tensor(np.random.random((8, 3, 8, 8)))
    bad_second = Tensor(np.random.random(five_d_shape))

    psnr_net = PSNRNet()
    # Checked in order: 5-D first input only, 5-D second input only, both 5-D.
    input_pairs = (
        (bad_first, valid_second),
        (valid_first, bad_second),
        (bad_first, bad_second),
    )
    for first, second in input_pairs:
        with pytest.raises(ValueError):
            _cell_graph_executor.compile(psnr_net, first, second)
106