# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')


class NetNorm(nn.Cell):
    """Apply four ``nn.Norm`` configurations to the same input.

    The variants cover reduction over axis 0, axis 1, the negative axis -1,
    and axis -1 with ``keep_dims=True``, so a single forward pass exercises
    each of them.
    """

    def __init__(self):
        super(NetNorm, self).__init__()

        self.norm_1 = nn.Norm(axis=0)
        self.norm_2 = nn.Norm(axis=1)
        self.norm_3 = nn.Norm(axis=-1)
        self.norm_4 = nn.Norm(axis=-1, keep_dims=True)

    # NOTE(review): ms_function is presumably applied so that construct goes
    # through graph compilation even though the module runs in PYNATIVE
    # mode -- confirm. Comments (not a docstring) are used inside construct
    # to keep the compiled body free of extra statements.
    @ms_function
    def construct(self, indices):
        # Returns the four norms as a tuple, in the order norm_1..norm_4.
        return (self.norm_1(indices),
                self.norm_2(indices),
                self.norm_3(indices),
                self.norm_4(indices))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_norm():
    """Check every ``NetNorm`` output against precomputed 2-norm values.

    The expected values are the Euclidean norms of the input along the given
    axis. For example, column 0 gives sqrt(4**2 + 2**2) = sqrt(20) = 4.472136,
    and row 0 gives sqrt(16 + 16 + 81 + 1) = sqrt(114) = 10.677078.

    The outputs are float32 reductions and the expectations are rounded
    decimals, so they are compared with a relative tolerance rather than
    bitwise equality. Bitwise equality is sensitive to kernel and summation
    order.
    """
    norm = NetNorm()
    indices = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
    output = norm(indices)
    expect_0 = np.array([4.472136, 4.1231055, 9.486833, 6.0827627]).astype(np.float32)
    expect_1 = np.array([10.677078, 7.071068]).astype(np.float32)
    expect_2 = np.array([10.677078, 7.071068]).astype(np.float32)
    # keep_dims=True keeps the reduced axis, so the result is (2, 1), not (2,).
    expect_3 = np.array([[10.677078], [7.071068]]).astype(np.float32)
    expects = (expect_0, expect_1, expect_2, expect_3)

    # zip() would silently stop early if fewer outputs came back.
    assert len(output) == len(expects)
    for out, expect in zip(output, expects):
        # assert_allclose also fails on a shape mismatch, unlike the
        # broadcasting `(a == b).all()` form.
        np.testing.assert_allclose(out.asnumpy(), expect, rtol=1e-5, atol=1e-6)