# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

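"""Train-one-step wrapper that synchronizes selected weights with a parameter server."""
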
from mindspore.context import ParallelMode
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

class TrainOneStepCellWithServerCommunicator(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer and communicator operators.
    The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameters. Different
    parallel modes are available for training.

    For now, this cell is only used in the hybrid training mode.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> # 1) Using the WithLossCell provided by MindSpore
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = TrainOneStepCellWithServerCommunicator(loss_net, optim)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, y, label):
        ...        out = self._backbone(x, y)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = TrainOneStepCellWithServerCommunicator(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellWithServerCommunicator, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
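        # get_by_list takes gradients w.r.t. the weight list; sens_param allows
        # feeding an external sensitivity (scaling) value into backpropagation.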
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

        self.hyper_map = C.HyperMap()

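        # PullWeight and PushWeight exchange individual parameters with the
        # parameter server, each addressed by a (weight, name, index) triple.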
        self.pull_weight_by_key = P.PullWeight()
        self.push_weight_by_key = P.PushWeight()

        (self.pull_weights,
         self.pull_weight_names,
         self.pull_weight_indices) = self._pull_weight_inputs(self.network.parameters_and_names())

        (self.push_weights,
         self.push_weight_names,
         self.push_weight_indices) = self._push_weight_inputs(self.network.parameters_and_names())

    def _pull_from_server(self, weights, names, indices):
        """Pull the given weights from the parameter server by key."""
        result = self.hyper_map(F.partial(self.pull_weight_by_key), weights, names, indices)
        return result

    def _push_to_server(self, weights, names, indices):
        """Push the given weights to the parameter server by key."""
        result = self.hyper_map(F.partial(self.push_weight_by_key), weights, names, indices)
        return result

    @staticmethod
    def _pull_weight_inputs(weights):
        """Gather the parameters, names and indices of the weights to be pulled from the server."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        for index, (_, weight) in enumerate(weights):
            if weight.pull_weight_from_server:
                filtered_weights.append(weight)
                weight_names.append(weight.name)
                weight_indices.append(index)

        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    @staticmethod
    def _push_weight_inputs(weights):
        """Gather the parameters, names and indices of the weights to be pushed to the server."""
        filtered_weights = []
        weight_names = []
        weight_indices = []
        for index, (_, weight) in enumerate(weights):
            if weight.push_weight_to_server:
                filtered_weights.append(weight)
                weight_names.append(weight.name)
                weight_indices.append(index)

        return ParameterTuple(filtered_weights), tuple(weight_names), tuple(weight_indices)

    def construct(self, *inputs):
        """Pull weights from the server, train one step, then push the updated weights back."""
        weights = self.weights
        res = self._pull_from_server(self.pull_weights,
                                     self.pull_weight_names, self.pull_weight_indices)
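        # F.depend forces the pull to complete before the inputs are consumed
        # by the forward pass.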
        inputs = F.depend(inputs, res)
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)

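        # The depend edges order the tail of the step: apply the optimizer
        # update first, then push the freshly updated weights to the server.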
        loss = F.depend(loss, self.optimizer(grads))
        push_weights = F.depend(self.push_weights, loss)
        loss = F.depend(loss, self._push_to_server(push_weights,
                                                   self.push_weight_names, self.push_weight_indices))
        return loss