# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.context import ParallelMode
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer


class TrainOneStepCellWithServerCommunicator(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer and communicator operators.
    The resulting Cell is trained with input '\*inputs'.
    The backward graph is created in the construct function to update the parameters. Different
    parallel modes are available for training.

    This cell is currently only used for the hybrid training mode.

    Args:
        network (Cell): The training network. The network only supports a single output.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> #1) Using the WithLossCell provided by MindSpore
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> #2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...         self._backbone = backbone
        ...         self._loss_fn = loss_fn
        ...
        ...     def construct(self, x, y, label):
        ...         out = self._backbone(x, y)
        ...         return self._loss_fn(out, label)
        ...
        ...     @property
        ...     def backbone_network(self):
        ...         return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellWithServerCommunicator, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

        self.hyper_map = C.HyperMap()

        # Operators that exchange weights with the server by key.
        self.pull_weight_by_key = P.PullWeight()
        self.push_weight_by_key = P.PushWeight()

        # Parameters tagged for server interaction, together with their names
        # and positions in the network's parameter list.
        self.pull_weights, self.pull_weight_names, self.pull_weight_indices = \
            self._pull_weight_inputs(self.network.parameters_and_names())
        self.push_weights, self.push_weight_names, self.push_weight_indices = \
            self._push_weight_inputs(self.network.parameters_and_names())

    def _pull_from_server(self, weights, names, indices):
        """Pull the given weights from the server by key."""
        result = self.hyper_map(F.partial(self.pull_weight_by_key), weights, names, indices)
        return result

    def _push_to_server(self, weights, names, indices):
        """Push the given weights to the server by key."""
        result = self.hyper_map(F.partial(self.push_weight_by_key), weights, names, indices)
        return result

    @staticmethod
    def _pull_weight_inputs(weights):
        """Collect the parameters to be pulled from the server, with their names and indices."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].pull_weight_from_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1

        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    @staticmethod
    def _push_weight_inputs(weights):
        """Collect the parameters to be pushed to the server, with their names and indices."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        index = 0
        for weight in weights:
            if weight[1].push_weight_to_server:
                filtered_weights.append(weight[1])
                weight_names.append(weight[1].name)
                weight_indices.append(index)
            index += 1

        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def construct(self, *inputs):
        weights = self.weights
        # Pull the latest weights from the server before running the step.
        res = self._pull_from_server(self.pull_weights,
                                     self.pull_weight_names, self.pull_weight_indices)
        inputs = F.depend(inputs, res)
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)

        # F.depend enforces the execution order: optimizer update first,
        # then push the updated weights back to the server.
        loss = F.depend(loss, self.optimizer(grads))
        push_weights = F.depend(self.push_weights, loss)
        loss = F.depend(loss, self._push_to_server(push_weights,
                                                   self.push_weight_names, self.push_weight_indices))
        return loss
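

# ----------------------------------------------------------------------------
# A minimal, hypothetical sketch (not part of the library) of how the static
# weight filters above behave. It assumes the surrounding hybrid-training
# setup has already tagged each Parameter through its `pull_weight_from_server`
# / `push_weight_to_server` attributes; here they are assigned by hand purely
# for illustration, and `nn.Dense(4, 2)` stands in for a real network.
if __name__ == "__main__":
    import mindspore.nn as nn

    example_net = nn.Dense(4, 2)
    for _, param in example_net.parameters_and_names():
        # Stand-in for whatever the hybrid-training configuration would
        # normally do when it marks parameters for server interaction.
        param.pull_weight_from_server = True
        param.push_weight_to_server = False

    pull_weights, pull_names, pull_indices = \
        TrainOneStepCellWithServerCommunicator._pull_weight_inputs(
            example_net.parameters_and_names())
    # Expected to show the tagged parameter names and their positions in the
    # parameter list, e.g. ('weight', 'bias') and (0, 1).
    print(pull_names, pull_indices)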