• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""ut for batchnorm layer"""
16import numpy as np
17import pytest
18
19import mindspore.nn as nn
20from mindspore import Tensor, Parameter
21from mindspore.common.api import _cell_graph_executor
22
23
def test_bn_pars_valid1():
    """BatchNorm2d construction with num_features=0 must raise ValueError."""
    invalid_num_features = 0
    with pytest.raises(ValueError):
        nn.BatchNorm2d(num_features=invalid_num_features)
28
29
def test_bn_pars_valid2():
    """BatchNorm2d construction with a negative momentum must raise ValueError."""
    invalid_momentum = -0.1
    with pytest.raises(ValueError):
        nn.BatchNorm2d(num_features=3, momentum=invalid_momentum)
34
35
def test_bn_init():
    """A freshly built BatchNorm2d holds gamma, beta, moving_mean and
    moving_variance as Parameter instances."""
    layer = nn.BatchNorm2d(num_features=3)

    # Same four checks as separate asserts, in the same order.
    for attr_name in ("gamma", "beta", "moving_mean", "moving_variance"):
        assert isinstance(getattr(layer, attr_name), Parameter)
44
45
class Net(nn.Cell):
    """Minimal cell holding one BatchNorm2d layer built with num_features=3."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=3)

    def construct(self, input_x):
        """Run ``input_x`` through the batch-norm layer and return the result."""
        normalized = self.bn(input_x)
        return normalized
53
54
def test_compile():
    """Compile Net with a random float32 tensor of shape (1, 3, 224, 224)."""
    net = Net()
    input_shape = [1, 3, 224, 224]
    # Integer samples drawn from [0, 255), then cast to float32.
    raw_values = np.random.randint(0, 255, input_shape)
    input_data = Tensor(raw_values.astype(np.float32))
    _cell_graph_executor.compile(net, input_data)
59
60
class GroupNet(nn.Cell):
    """Minimal cell holding one GroupNorm layer.

    Args:
        num_groups: value forwarded to ``nn.GroupNorm`` as its first
            argument. Defaults to 16.
        num_channels: value forwarded to ``nn.GroupNorm`` as its second
            argument. Defaults to 64.
    """

    def __init__(self, num_groups=16, num_channels=64):
        super(GroupNet, self).__init__()
        # nn.GroupNorm requires both num_groups and num_channels; it cannot
        # be constructed without arguments, so they are supplied here.
        self.group_bn = nn.GroupNorm(num_groups, num_channels)

    def construct(self, x):
        """Run ``x`` through the group-norm layer and return the result."""
        return self.group_bn(x)


def test_compile_groupnorm():
    """Compile GroupNet (16 groups, 64 channels) with a random float32
    tensor of shape (1, 64, 256, 256)."""
    # Go through the GroupNet wrapper, mirroring how test_compile uses Net.
    net = GroupNet(16, 64)
    input_data = Tensor(np.random.rand(1, 64, 256, 256).astype(np.float32))
    _cell_graph_executor.compile(net, input_data)
74