• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import copy
17import numpy as np
18import pytest
19
20import mindspore.context as context
21from mindspore import Tensor
22import mindspore.nn as nn
23from mindspore.ops.operations import _grad_ops as G
24import mindspore.ops.operations as P
25
26
class LayerNormNet(nn.Cell):
    """Thin Cell wrapper around the forward LayerNorm primitive.

    Exists so the op can be run both with and without graph-kernel
    fusion and the results compared.
    """

    def __init__(self, begin_norm_axis, begin_params_axis):
        super().__init__()
        self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)

    def construct(self, x, gamma, beta):
        # Returns the primitive's full output tuple (y, mean, variance).
        return self.layernorm(x, gamma, beta)
34
35
class LayerNormGradNet(nn.Cell):
    """Thin Cell wrapper around the LayerNormGrad primitive.

    Exists so the grad op can be run both with and without graph-kernel
    fusion and the results compared.
    """

    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
        self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, x, dy, var, mean, gamma):
        # Fix: parameters were previously named (dy, x, ...) even though the
        # only caller, get_layernorm_grad_output, passes (x, dy, ...)
        # positionally — the names were swapped relative to the tensors they
        # actually bound. Renamed to match the call site; the positional
        # order forwarded to the primitive is unchanged, so behavior is
        # identical.
        return self.layernorm_grad(x, dy, var, mean, gamma)
43
def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
    """Run LayerNormNet once, with graph-kernel fusion toggled as requested.

    Returns whatever the LayerNorm primitive produces for (x, gamma, beta).
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    return net(x, gamma, beta)
51
52
def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
    """Run LayerNormGradNet once, with graph-kernel fusion toggled as requested.

    Returns whatever the LayerNormGrad primitive produces for the given inputs.
    """
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    return net(x, dy, var, mean, gamma)
60
def get_rtol_atol(dtype):
    """Return (rtol, atol) comparison tolerances for *dtype*.

    fp16 results get looser tolerances (1e-3) than everything else (1e-4).
    """
    loose = dtype == np.float16
    return (1.e-3, 1.e-3) if loose else (1.e-4, 1.e-4)
65
66
def compare_result(expect, output, dtype):
    """Assert *output* matches *expect* element-wise within dtype tolerances.

    Both arguments are either single tensors (anything with .asnumpy()) or
    parallel sequences of them.
    """
    rtol, atol = get_rtol_atol(dtype)
    if not isinstance(expect, (list, tuple)):
        assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
        return
    assert isinstance(output, (list, tuple)) and len(expect) == len(output)
    for exp_item, out_item in zip(expect, output):
        assert np.allclose(exp_item.asnumpy(), out_item.asnumpy(), rtol, atol, equal_nan=True)
77
78
def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    """Check LayerNorm agrees with and without graph-kernel fusion.

    NOTE(review): despite the test_ prefix this is a parameterized helper
    invoked by the platform-specific tests below; pytest cannot collect it
    directly because shape/dtype have no defaults.
    """
    rank = len(shape)
    # Normalize negative axes to their non-negative equivalents.
    if begin_norm_axis < 0:
        begin_norm_axis += rank
    if begin_params_axis < 0:
        begin_params_axis += rank
    assert 0 <= begin_norm_axis < rank
    assert 0 <= begin_params_axis < rank
    normalized_shape = shape[begin_params_axis:]

    # Fixed seed so both runs see identical inputs.
    np.random.seed(0)
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
    beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))

    expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
    compare_result(expect, output, dtype)
96
97
def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    """Check LayerNormGrad agrees with and without graph-kernel fusion.

    NOTE(review): despite the test_ prefix this is a parameterized helper
    invoked by the platform-specific tests below; pytest cannot collect it
    directly because shape/dtype have no defaults.
    """
    rank = len(shape)
    # Normalize negative axes to their non-negative equivalents.
    if begin_norm_axis < 0:
        begin_norm_axis += rank
    if begin_params_axis < 0:
        begin_params_axis += rank
    assert 0 <= begin_norm_axis < rank
    assert 0 <= begin_params_axis < rank

    # mean/var keep the input rank, with every normalized axis reduced to 1.
    norm_shape = [1 if axis >= begin_norm_axis else dim for axis, dim in enumerate(shape)]
    params_shape = shape[begin_params_axis:]

    # Fixed seed so both runs see identical inputs.
    np.random.seed(0)
    dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))

    expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
    compare_result(expect, output, dtype)
123
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_gpu():
    """LayerNorm graph-kernel consistency on GPU, fp32."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_layernorm([4, 32, 32], np.float32)
130
131
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_ascend():
    """LayerNorm graph-kernel consistency on Ascend, fp16 and fp32."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    for dtype in (np.float16, np.float32):
        test_layernorm([4, 32, 32], dtype)
140
141
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernorm_grad_gpu():
    """LayerNormGrad graph-kernel consistency on GPU, fp32."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_layernorm_grad([4, 32, 32], np.float32)
148
149
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_grad_ascend():
    """LayerNormGrad graph-kernel consistency on Ascend, fp16 and fp32."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    for shape, dtype in (([2, 16, 32], np.float16), ([4, 32, 32], np.float32)):
        test_layernorm_grad(shape, dtype)
158