• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_momentum """
16import functools
17import numpy as np
18
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Parameter, ParameterTuple, Tensor
22from mindspore.ops import composite as C
23from mindspore.ops import functional as F
24from mindspore.ops import operations as P
25from ..ut_filter import non_graph_engine
26from ....mindspore_test_framework.mindspore_test import mindspore_test
27from ....mindspore_test_framework.pipeline.forward.compile_forward \
28    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
29
# pylint: disable=W0613
# W0613: unused-argument
32
33
# Gradient operator that differentiates a network with respect to an explicit
# ParameterTuple (get_by_list=True); TrainStepWrap.construct calls it.
grad_by_list = C.GradOperation(get_by_list=True)

# Type-dispatched function graph. Overloads are added via @run_opt.register and
# the graph is applied per parameter through HyperMap in OptimizerByMomentum.
run_opt = C.MultitypeFuncGraph("run_opt")
38
39
@run_opt.register("Function",
                  "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def tensor_run_opt(opt, iters, learning_rate, momentum,
                   gradient, variable, moment):
    """Apply one optimizer step to a single parameter and write the result back.

    Registered on ``run_opt`` for a (Function, Tensor x 6) argument signature.

    Args:
        opt: update primitive, invoked as
            ``opt(variable, moment, learning_rate, gradient, momentum)``
            (``P.ApplyMomentum`` when called from OptimizerByMomentum).
        iters: iteration counter; accepted to match the registered signature
            but not used here.
        learning_rate: learning-rate value forwarded to ``opt``.
        momentum: momentum coefficient forwarded to ``opt``.
        gradient: gradient for ``variable``.
        variable: parameter being updated; receives ``opt``'s output via assign.
        moment: accumulation buffer paired with ``variable``.

    Returns:
        ``True``, routed through ``F.depend`` on the assign. Presumably this keeps
        the side-effecting assign attached to the graph output.
    """
    updated_weight = opt(variable, moment, learning_rate, gradient, momentum)
    return F.depend(True, F.assign(variable, updated_weight))
50
51
class OptimizerByMomentum(nn.Cell):
    """Momentum optimizer that applies ``P.ApplyMomentum`` to every weight.

    Args:
        weights: ParameterTuple of trainable parameters. A zero-initialised
            clone with prefix ``"moments"`` is created to hold the
            per-parameter accumulation buffers.
    """

    def __init__(self, weights):
        super().__init__()
        # Hyper-parameters held as Parameters (fixed values: lr 0.1, momentum 0.05).
        self.learning_rate = Parameter(0.1, name="learning_rate")
        self.momentum = Parameter(0.05, name="momentum")
        # Passed through to run_opt as ``iters``; tensor_run_opt does not read it.
        self.iter = Parameter(0, name="iter")

        self.weights = weights
        self.moments = weights.clone(prefix="moments", init='zeros')

        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()

    def construct(self, grads):
        """Run ``run_opt`` over (grads, weights, moments) element-wise.

        Returns whatever ``hyper_map`` yields for the per-parameter results.
        """
        return self.hyper_map(F.partial(run_opt, self.opt, self.iter,
                                        self.learning_rate, self.momentum),
                              grads, self.weights, self.moments)
75
76
class TrainStepWrap(nn.Cell):
    """One training step: differentiate ``network`` and apply the momentum update.

    Args:
        network: loss cell called as ``network(x, label)``. Its parameters
            become the tuple that is differentiated and optimised.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = OptimizerByMomentum(self.weights)
        # NOTE(review): construct never uses this; it is kept so the cell's attributes are unchanged.
        self.hyper_map = C.HyperMap()

    def construct(self, x, label):
        """Compute gradients of ``network(x, label)`` w.r.t. ``self.weights`` and
        return the optimizer's output for them."""
        grads = grad_by_list(self.network, self.weights)(x, label)
        return self.optimizer(grads)
91
92
class NetWithLossClass(nn.Cell):
    """Attach ``nn.SoftmaxCrossEntropyWithLogits`` to the output of ``network``.

    ``auto_prefix=False`` is forwarded to ``nn.Cell``. Presumably this stops this
    wrapper from adding a prefix to the inner network's parameter names.

    Args:
        network: cell mapping ``x`` to logits.
    """

    def __init__(self, network):
        super().__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        """Return the loss between ``network(x)`` and ``label``."""
        return self.loss(self.network(x), label)
104
105
class Net(nn.Cell):
    """Single dense layer computing ``BiasAdd(MatMul(x, weight), bias)``.

    ``weight`` is a float32 [64, 10] Parameter of ones and ``bias`` a float32
    [10] Parameter of ones. ``x`` is assumed to be [batch, 64]; TODO confirm
    for other callers.
    """

    def __init__(self):
        super().__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10], dtype=np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10], dtype=np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        """Return ``x @ weight + bias``."""
        projected = self.matmul(x, self.weight)
        return self.biasAdd(projected, self.bias)
118
119
# Each case is a (name, config) pair. 'block' is the cell under test and
# 'desc_inputs' are its inputs: x as float32 ones [1, 64], label as float32 zeros [1, 10].
# Presumably the mindspore_test pipeline consumes these keys; verify against it.
test_case_ops = [
    ('Momentum', {
        'block': TrainStepWrap(NetWithLossClass(Net())),
        'desc_inputs': [
            Tensor(np.ones([1, 64], dtype=np.float32)),
            Tensor(np.zeros([1, 10], dtype=np.float32)),
        ],
    }),
]

test_case_lists = [test_case_ops]
# Concatenate every case list into one flat list of cases.
test_exec_case = functools.reduce(lambda merged, cases: merged + cases, test_case_lists)
# Use -k to select specific test cases, e.g.:
# pytest path/to/this_file.py::test_exec -k Momentum

@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Switch the global context to GRAPH_MODE and hand the collected cases to the pipeline.

    Returns ``test_exec_case``. Presumably the ``mindspore_test`` decorator
    runs each case through the forward-compile pipeline it was given.
    """
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case
139