# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test lamb """
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Lamb
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


class LambLearningRate(LearningRateSchedule):
    """ Linear warmup followed by polynomial decay. """

    def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
        super(LambLearningRate, self).__init__()
        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))

        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        # Use the warmup schedule while global_step < warmup_steps, the decay schedule afterwards.
        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
        warmup_lr = self.warmup_lr(global_step)
        decay_lr = self.decay_lr(global_step)
        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        return lr


class Net(nn.Cell):
    """ Net definition """

    def __init__(self):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
        self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
        self.matmul = P.MatMul()
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        x = self.bias_add(self.matmul(x, self.weight), self.bias)
        return x


class NetWithoutWeight(nn.Cell):
    """ NetWithoutWeight definition """

    def __init__(self):
        super(NetWithoutWeight, self).__init__()
        self.matmul = P.MatMul()

    def construct(self, x):
        x = self.matmul(x, x)
        return x


def test_lamb_compile_dynamic_lr():
    """ test Lamb compile with a warmup/decay learning rate schedule """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
    optimizer = Lamb(net.trainable_params(), warmup_decay_lr)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)


def test_lamb_compile():
    """ test Lamb compile with a fixed learning rate """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    optimizer = Lamb(net.trainable_params(), 0.02, 0.9)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)


def test_lamb_group():
    """ test Lamb compile with grouped parameters """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    warmup_decay_lr = LambLearningRate(0.01, 0.0001, 10, 20, 1.0)
    all_params = net.trainable_params()
    # The first group overrides both the learning rate and the weight decay;
    # the second group falls back to the optimizer-level defaults.
    group_params = [{'params': [all_params[0]], 'lr': warmup_decay_lr, 'weight_decay': 0.9},
                    {'params': [all_params[1]]}]
    optimizer = Lamb(group_params, 0.02)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)
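

# A minimal additional sketch (not in the original file), mirroring test_lamb_compile:
# it exercises Lamb's optimizer-level weight_decay argument, which the tests above
# only set per parameter group; weight_decay=0.01 is an assumed illustrative value.
def test_lamb_compile_weight_decay():
    """ test Lamb compile with optimizer-level weight decay """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    optimizer = Lamb(net.trainable_params(), learning_rate=0.02, weight_decay=0.01)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _cell_graph_executor.compile(train_network, inputs, label)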