• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore.context as context
19from .optimizer_utils import FakeNet, build_network
20from tests.st.utils import test_utils
21
22
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@test_utils.run_test_with_On
def test_default_asgd(mode):
    """
    Feature: Test ASGD optimizer
    Description: Test ASGD with default parameters, in both graph and PyNative mode.
    Expectation: The optimizer's `ax` entries for the fc1/fc2 weights and biases match
        the preset reference values within atol=1e-3.
    """
    from .optimizer_utils import default_fc1_weight_asgd, \
        default_fc1_bias_asgd, default_fc2_weight_asgd, default_fc2_bias_asgd
    context.set_context(mode=mode)
    config = {'name': 'ASGD', 'lr': 0.01, 'lambd': 1e-4, 'alpha': 0.75, 't0': 1e6, 'weight_decay': 0.0}
    # The first return value of build_network (presumably the per-step losses) is
    # deliberately not checked; only the optimizer state is compared below.
    _, cells = build_network(config, FakeNet())
    # Reference values, in the same index order as cells.ax.
    # NOTE(review): `ax` is assumed to be ASGD's averaged-parameter state -- confirm in
    # optimizer_utils / the ASGD implementation.
    expected = (default_fc1_weight_asgd, default_fc1_bias_asgd,
                default_fc2_weight_asgd, default_fc2_bias_asgd)
    for idx, ref in enumerate(expected):
        assert np.allclose(cells.ax[idx].asnumpy(), ref, atol=1.e-3), \
            f"ax[{idx}] deviates from the preset reference value"
47
48
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_no_default_asgd(mode):
    """
    Feature: Test ASGD optimizer
    Description: Test ASGD with another set of parameters (including non-zero
        weight_decay), in both graph and PyNative mode.
    Expectation: The optimizer's `ax` entries for the fc1/fc2 weights and biases match
        the preset reference values within atol=1e-3.
    """
    from .optimizer_utils import no_default_fc1_weight_asgd, \
        no_default_fc1_bias_asgd, no_default_fc2_weight_asgd, no_default_fc2_bias_asgd
    context.set_context(mode=mode)
    config = {'name': 'ASGD', 'lr': 0.001, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
    # The first return value of build_network (presumably the per-step losses) is
    # deliberately not checked; only the optimizer state is compared below.
    _, cells = build_network(config, FakeNet())
    # Reference values, in the same index order as cells.ax.
    # NOTE(review): `ax` is assumed to be ASGD's averaged-parameter state -- confirm in
    # optimizer_utils / the ASGD implementation.
    expected = (no_default_fc1_weight_asgd, no_default_fc1_bias_asgd,
                no_default_fc2_weight_asgd, no_default_fc2_bias_asgd)
    for idx, ref in enumerate(expected):
        assert np.allclose(cells.ax[idx].asnumpy(), ref, atol=1.e-3), \
            f"ax[{idx}] deviates from the preset reference value"
72
73
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_default_asgd_group(mode):
    """
    Feature: Test ASGD optimizer
    Description: Test ASGD with parameter grouping (build_network(..., is_group=True)),
        in both graph and PyNative mode.
    Expectation: The optimizer's `ax` entries for the fc1/fc2 weights and biases match
        the preset reference values within atol=1e-3.
    """
    # NOTE(review): despite the "default" in the test name, this test uses a
    # non-default hyperparameter set and the `no_default_group_*` reference values.
    from .optimizer_utils import no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd, \
        no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd
    context.set_context(mode=mode)
    config = {'name': 'ASGD', 'lr': 0.1, 'lambd': 1e-3, 'alpha': 0.8, 't0': 50., 'weight_decay': 0.001}
    # The first return value of build_network (presumably the per-step losses) is
    # deliberately not checked; only the optimizer state is compared below.
    _, cells = build_network(config, FakeNet(), is_group=True)
    # Reference values, in the same index order as cells.ax.
    # NOTE(review): `ax` is assumed to be ASGD's averaged-parameter state -- confirm in
    # optimizer_utils / the ASGD implementation.
    expected = (no_default_group_fc1_weight_asgd, no_default_group_fc1_bias_asgd,
                no_default_group_fc2_weight_asgd, no_default_group_fc2_bias_asgd)
    for idx, ref in enumerate(expected):
        assert np.allclose(cells.ax[idx].asnumpy(), ref, atol=1.e-3), \
            f"ax[{idx}] deviates from the preset reference value"
97