# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train step wrap

Helpers that wrap a network into cells performing a single gradient /
optimizer step, for use in tests.
"""
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore.ops import composite as C


class TrainStepWrap(nn.Cell):
    """
    TrainStepWrap definition

    Compute the gradients of ``network`` with respect to its trainable
    parameters and apply them with an ``nn.Momentum`` optimizer.

    Args:
        network (Cell): cell whose output is differentiated; it is switched
            to training mode with ``set_train()``.
        learning_rate (float): Momentum learning rate. Default: 0.1.
        momentum (float): Momentum factor. Default: 0.9.

    Inputs:
        x, label: forwarded unchanged to ``network``.

    Outputs:
        The value returned by the ``nn.Momentum`` optimizer call.
    """

    def __init__(self, network, learning_rate=0.1, momentum=0.9):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.network.set_train()
        # Only trainable parameters are differentiated and updated.
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = nn.Momentum(self.weights, learning_rate, momentum)
        # NOTE(review): not used by construct(); kept as a public attribute.
        self.hyper_map = C.HyperMap()
        # get_by_list=True: return gradients w.r.t. the weights passed in.
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, x, label):
        weights = self.weights
        grads = self.grad(self.network, weights)(x, label)
        return self.optimizer(grads)


class NetWithLossClass(nn.Cell):
    """
    NetWithLossClass definition

    Run ``network`` on ``x`` and score its output against ``label`` with
    ``nn.SoftmaxCrossEntropyWithLogits``.

    Args:
        network (Cell): the backbone producing logits.

    Inputs:
        x: input forwarded to ``network``.
        label: target passed to the loss.

    Outputs:
        The loss value.
    """

    def __init__(self, network):
        # auto_prefix=False keeps this wrapper from adding its own prefix to
        # the parameter names of the wrapped network.
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits()
        self.network = network

    def construct(self, x, label):
        predict = self.network(x)
        return self.loss(predict, label)


def train_step_with_loss_wrap(network, learning_rate=0.1, momentum=0.9):
    """
    Build a TrainStepWrap around ``network`` combined with a softmax
    cross-entropy loss.

    Args:
        network (Cell): the backbone producing logits.
        learning_rate (float): Momentum learning rate. Default: 0.1.
        momentum (float): Momentum factor. Default: 0.9.

    Returns:
        TrainStepWrap: cell taking ``(x, label)`` and performing one step.
    """
    return TrainStepWrap(NetWithLossClass(network), learning_rate, momentum)


# Backward-compatible alias for the original (misspelled) name.
train_step_with_loss_warp = train_step_with_loss_wrap


class TrainStepWrap2(nn.Cell):
    """
    TrainStepWrap2 definition

    Like TrainStepWrap, but differentiates with an explicit sensitivity
    (initial output gradient) and takes a single input.

    Args:
        network (Cell): cell to differentiate; switched to training mode.
        sens: sensitivity passed to the gradient function on every call.
        learning_rate (float): Momentum learning rate. Default: 0.1.
        momentum (float): Momentum factor. Default: 0.9.

    Inputs:
        x: input forwarded to ``network``.

    Outputs:
        The value returned by the ``nn.Momentum`` optimizer call.
    """

    def __init__(self, network, sens, learning_rate=0.1, momentum=0.9):
        super(TrainStepWrap2, self).__init__()
        self.network = network
        self.network.set_train()
        # NOTE(review): unlike TrainStepWrap this uses get_parameters()
        # (all parameters, not only trainable ones) -- confirm intentional.
        self.weights = ParameterTuple(network.get_parameters())
        self.optimizer = nn.Momentum(self.weights, learning_rate, momentum)
        # NOTE(review): not used by construct(); kept as a public attribute.
        self.hyper_map = C.HyperMap()
        # sens_param=True: the gradient function expects the sensitivity as
        # an extra trailing argument (supplied from self.sens below).
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, x):
        weights = self.weights
        grads = self.grad(self.network, weights)(x, self.sens)
        return self.optimizer(grads)


def train_step_with_sens(network, sens, learning_rate=0.1, momentum=0.9):
    """
    Build a TrainStepWrap2 around ``network`` with sensitivity ``sens``.

    Args:
        network (Cell): cell to differentiate.
        sens: sensitivity passed to the gradient function.
        learning_rate (float): Momentum learning rate. Default: 0.1.
        momentum (float): Momentum factor. Default: 0.9.

    Returns:
        TrainStepWrap2: cell taking ``x`` and performing one step.
    """
    return TrainStepWrap2(network, sens, learning_rate, momentum)


class TrainStepWrapWithoutOpt(nn.Cell):
    """
    TrainStepWrapWithoutOpt definition

    Compute gradients of ``network`` with respect to its trainable
    parameters and return them without applying any optimizer.

    Args:
        network (Cell): cell to differentiate.

    Inputs:
        x, label: forwarded unchanged to ``network``.

    Outputs:
        The gradients, one per trainable parameter.
    """

    def __init__(self, network):
        super(TrainStepWrapWithoutOpt, self).__init__()
        # NOTE(review): unlike the other wrappers, set_train() is not called
        # here -- confirm whether that is intentional.
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True)

    def construct(self, x, label):
        grads = self.grad(self.network, self.weights)(x, label)
        return grads


def train_step_without_opt(network):
    """
    Build a TrainStepWrapWithoutOpt around ``network`` combined with a
    softmax cross-entropy loss.

    Args:
        network (Cell): the backbone producing logits.

    Returns:
        TrainStepWrapWithoutOpt: cell taking ``(x, label)`` and returning
        gradients.
    """
    return TrainStepWrapWithoutOpt(NetWithLossClass(network))