# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_momentum

Compile-level test for a hand-rolled momentum optimizer: a tiny
MatMul+BiasAdd network is wrapped with a softmax cross-entropy loss and a
manual train step whose parameter update is applied per-parameter through
`C.HyperMap` and `P.ApplyMomentum`.
"""
import functools
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Parameter, ParameterTuple, Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

# pylint: disable=W0613
# W0613: unused-argument


# Multitype function graph: dispatches the per-parameter update below based
# on the argument type signature registered via @run_opt.register.
run_opt = C.MultitypeFuncGraph("run_opt")


# Gradient operation that returns gradients w.r.t. a ParameterTuple
# (get_by_list=True) rather than w.r.t. the network inputs.
grad_by_list = C.GradOperation(get_by_list=True)


@run_opt.register("Function", "Tensor", "Tensor", "Tensor",
                  "Tensor", "Tensor",
                  "Tensor")
def tensor_run_opt(opt, iters, learning_rate, momentum,
                   gradient, variable, moment):
    """Apply one momentum update to a single parameter.

    Args:
        opt: the `P.ApplyMomentum` primitive to invoke.
        iters: iteration counter parameter (unused here; see W0613 above).
        learning_rate: scalar learning-rate parameter.
        momentum: scalar momentum parameter.
        gradient: gradient of the loss w.r.t. `variable`.
        variable: the parameter being updated (in place).
        moment: the accumulated momentum buffer for `variable`.

    Returns:
        A boolean success flag, with the assign attached via `F.depend`
        so the in-place update is not pruned from the graph.
    """
    success = True
    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
    # depend() keeps the side-effecting assign in the dataflow graph.
    success = F.depend(success, F.assign(variable, new_weight))
    return success


class OptimizerByMomentum(nn.Cell):
    """ OptimizerByMomentum definition

    Minimal momentum optimizer implemented as a Cell: clones the given
    weights into zero-initialized momentum buffers and maps the registered
    `run_opt` update over (grads, weights, moments) triples.
    """

    def __init__(self, weights):
        super(OptimizerByMomentum, self).__init__()
        self.learning_rate = Parameter(0.1, name="learning_rate")
        self.momentum = Parameter(0.05, name="momentum")
        self.iter = Parameter(0, name="iter")

        self.weights = weights
        # One zero-initialized momentum buffer per trainable weight.
        self.moments = weights.clone(prefix="moments", init='zeros')

        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()

    def construct(self, grads):
        """Update every weight from its gradient; return success flags."""
        success = True
        weights = self.weights
        moments = self.moments
        # partial() fixes the shared arguments (opt, iter, lr, momentum);
        # hyper_map then iterates the per-parameter triples in lockstep.
        success = self.hyper_map(F.partial(run_opt, self.opt, self.iter,
                                           self.learning_rate, self.momentum),
                                 grads, weights, moments)
        return success


class TrainStepWrap(nn.Cell):
    """ TrainStepWrap definition

    One training step: compute gradients of `network` w.r.t. its
    parameters and feed them to the momentum optimizer above.
    """

    def __init__(self, network):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = OptimizerByMomentum(self.weights)
        self.hyper_map = C.HyperMap()

    def construct(self, x, label):
        weights = self.weights
        # grad_by_list(net, weights) builds the grad graph; calling it with
        # (x, label) yields one gradient per entry of `weights`.
        grads = grad_by_list(self.network, weights)(x, label)
        return self.optimizer(grads)


class NetWithLossClass(nn.Cell):
    """ NetWithLossClass definition

    Wraps `network` with a softmax cross-entropy loss so the whole
    forward pass ends in a scalar suitable for differentiation.
    """

    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


class Net(nn.Cell):
    """ Net definition

    Single dense layer expressed with raw primitives:
    y = x @ weight + bias, with weight (64, 10) and bias (10,).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.biasAdd = P.BiasAdd()

    def construct(self, x):
        return self.biasAdd(self.matmul(x, self.weight), self.bias)


# Case table consumed by the mindspore_test pipeline: each entry names a
# case, the Cell under test, and its concrete inputs.
test_case_ops = [
    ('Momentum', {
        'block': TrainStepWrap(NetWithLossClass(Net())),
        'desc_inputs': [Tensor(np.ones([1, 64]).astype(np.float32)),
                        Tensor(np.zeros([1, 10]).astype(np.float32))]}),
]

test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcase
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm


@non_graph_engine
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    """Compile/run the case table above in graph mode via the test pipeline."""
    context.set_context(mode=context.GRAPH_MODE)
    return test_exec_case