# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_lr_schedule """
import numpy as np

from mindspore import Parameter, ParameterTuple, Tensor
from mindspore.nn import Cell
from mindspore.nn.optim import Optimizer
from mindspore.ops.operations import BiasAdd, MatMul
import mindspore.ops.composite as C


# Gradient operator configured with get_by_list=True: it is applied below to a
# ParameterTuple of weights, so gradients are taken w.r.t. those parameters.
grad_by_list = C.GradOperation(get_by_list=True)


class Net(Cell):
    """ Net definition

    A single affine layer computing ``BiasAdd(MatMul(x, weight), bias)`` with
    a (64, 10) weight and a (10,) bias, both initialised to ones.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight")
        self.bias = Parameter(Tensor(np.ones([10])), name="bias")
        self.matmul = MatMul()
        self.biasAdd = BiasAdd()

    def construct(self, x):
        """Return ``BiasAdd(MatMul(x, weight), bias)``.

        Given the (64, 10) weight, ``x`` is presumably (batch, 64), giving a
        (batch, 10) result -- confirm against the callers.
        """
        x = self.biasAdd(self.matmul(x, self.weight), self.bias)
        return x


class _TrainOneStepCell(Cell):
    """ _TrainOneStepCell definition

    Bundles a training network and an optimizer so that one call computes
    the gradients of the network w.r.t. its parameters and passes them to
    the optimizer.
    """

    def __init__(self, network, optimizer):
        """
        Append an optimizer to the training network; afterwards the
        construct function can be called to create the backward graph.

        Arguments:
            network: The training network.
                Note that loss function should have been added.
            optimizer: optimizer for updating the weights

        Raises:
            TypeError: If ``optimizer`` is not an ``Optimizer`` instance.
        """
        super(_TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        # Parameters of the wrapped network; gradients are taken w.r.t. these.
        self.weights = ParameterTuple(network.get_parameters())

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an optimizer'.format(
                type(optimizer).__name__))

        # Learning-rate scheduling is disabled by default.
        # NOTE(review): enabling it also requires assigning ``self.schedule``
        # (an object exposing ``update_lr``); nothing in this class sets it.
        self.has_lr_schedule = False
        self.optimizer = optimizer

    def construct(self, data, label, *args):
        """Run one training step.

        Arguments:
            data: First input forwarded to ``network``.
            label: Second input forwarded to ``network``.
            *args: Forwarded to ``self.schedule.update_lr`` when
                ``has_lr_schedule`` is enabled; otherwise unused.

        Returns:
            Whatever ``self.optimizer(grads)`` returns.
        """
        weights = self.weights
        grads = grad_by_list(self.network, weights)(data, label)
        # Gate on the flag actually set in __init__ (the former check read
        # ``self.lr_schedule``, which is never defined).
        if self.has_lr_schedule:
            self.schedule.update_lr(*args)
        return self.optimizer(grads)