# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ut for batchnorm and groupnorm layers"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor


def test_bn_pars_valid1():
    """ut of BatchNorm parameters' validation: num_features=0 must be rejected."""
    with pytest.raises(ValueError):
        nn.BatchNorm2d(num_features=0)


def test_bn_pars_valid2():
    """ut of BatchNorm parameters' validation: a negative momentum must be rejected."""
    with pytest.raises(ValueError):
        nn.BatchNorm2d(num_features=3, momentum=-0.1)


def test_bn_init():
    """ut of BatchNorm initialization.

    Checks that gamma, beta, moving_mean and moving_variance are all
    created as ``Parameter`` objects.
    """
    bn = nn.BatchNorm2d(num_features=3)

    assert isinstance(bn.gamma, Parameter)
    assert isinstance(bn.beta, Parameter)
    assert isinstance(bn.moving_mean, Parameter)
    assert isinstance(bn.moving_variance, Parameter)


class Net(nn.Cell):
    """Minimal cell wrapping a 3-feature BatchNorm2d; compiled by test_compile."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=3)

    def construct(self, input_x):
        """Apply the wrapped BatchNorm2d to ``input_x``."""
        return self.bn(input_x)


def test_compile():
    """Graph-compile Net on a random 1x3x224x224 float32 input."""
    net = Net()
    input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    _cell_graph_executor.compile(net, input_data)


class GroupNet(nn.Cell):
    """Minimal cell wrapping nn.GroupNorm; compiled by test_compile_groupnorm.

    Args:
        num_groups (int): number of groups forwarded to nn.GroupNorm. Default: 16.
        num_channels (int): number of channels forwarded to nn.GroupNorm. Default: 64.
    """

    def __init__(self, num_groups=16, num_channels=64):
        super().__init__()
        # nn.GroupNorm takes the group/channel counts as required arguments
        # (per the MindSpore API docs), so a bare nn.GroupNorm() call cannot
        # construct the layer; forward explicit values instead.
        self.group_bn = nn.GroupNorm(num_groups, num_channels)

    def construct(self, x):
        """Apply the wrapped GroupNorm to ``x``."""
        return self.group_bn(x)


def test_compile_groupnorm():
    """Graph-compile GroupNet (16 groups, 64 channels) on a random 1x64x256x256 float32 input."""
    net = GroupNet(16, 64)
    input_data = Tensor(np.random.rand(1, 64, 256, 256).astype(np.float32))
    _cell_graph_executor.compile(net, input_data)