1 2# Copyright 2020 Huawei Technologies Co., Ltd 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================ 16"""Cell_wrapper.""" 17from __future__ import absolute_import 18from __future__ import division 19 20import os 21from types import FunctionType, MethodType 22 23from mindspore import log as logger 24from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\ 25 _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel 26from mindspore.context import ParallelMode, GRAPH_MODE, get_context 27from mindspore import _checkparam as validator 28from mindspore import ops, nn 29from mindspore.common import dtype as mstype 30from mindspore.common.parameter import Parameter, ParameterTuple 31from mindspore.common.tensor import Tensor 32from mindspore.ops.primitive import _primexpr 33from mindspore.ops import composite as C 34from mindspore.ops import functional as F 35from mindspore.ops import operations as P 36from mindspore.ops.operations.comm_ops import _VirtualDataset 37from mindspore.nn.cell import Cell 38from mindspore.nn.wrap.grad_reducer import DistributedGradReducer 39 40_get_datatype = C.MultitypeFuncGraph("_get_datatype") 41 42 43@_get_datatype.register("Tensor") 44def _tensors_get_datatype(param): 45 """ 46 Acquire parameter datatype. 47 48 Args: 49 param (Tensor): The parameter before operation. 50 51 Returns: 52 mstype, the datatype of parameter. 53 """ 54 return F.dtype(param) 55 56 57_cast_datatype = C.MultitypeFuncGraph("_cast_datatype") 58 59 60@_cast_datatype.register("TypeType", "Tensor") 61def _tensors_cast_datatype(datatype, param): 62 """ 63 Cast gradient to datatype. 64 65 Args: 66 datatype (mstype): the destination datatype of parameter. 67 param (Tensor): The parameter before operation. 68 69 Returns: 70 Tensor, the parameter after operation. 71 """ 72 return F.cast(param, datatype) 73 74 75class WithLossCell(Cell): 76 r""" 77 Cell with loss function. 78 79 Wraps the network with loss function. This Cell accepts data and label as inputs and 80 the computed loss will be returned. 81 82 Args: 83 backbone (Cell): The backbone network to wrap. 84 loss_fn (Cell): The loss function used to compute loss. 85 86 Inputs: 87 - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. 88 - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. 89 90 Outputs: 91 Tensor, a tensor means the loss value, the shape of which is usually :math:`()`. 92 93 Raises: 94 TypeError: If dtype of `data` or `label` is neither float16 nor float32. 95 96 Supported Platforms: 97 ``Ascend`` ``GPU`` ``CPU`` 98 99 Examples: 100 >>> import mindspore as ms 101 >>> from mindspore import Tensor, nn 102 >>> import numpy as np 103 >>> # Define the network structure of LeNet5. Refer to 104 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 105 >>> net = LeNet5() 106 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) 107 >>> net_with_criterion = nn.WithLossCell(net, loss_fn) 108 >>> 109 >>> batch_size = 2 110 >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01) 111 >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32)) 112 >>> 113 >>> output_data = net_with_criterion(data, label) 114 """ 115 116 def __init__(self, backbone, loss_fn): 117 super(WithLossCell, self).__init__(auto_prefix=False) 118 self._backbone = backbone 119 self._loss_fn = loss_fn 120 self._get_attr_from_cell(backbone) 121 122 def construct(self, data, label): 123 out = self._backbone(data) 124 return self._loss_fn(out, label) 125 126 @property 127 def backbone_network(self): 128 """ 129 Get the backbone network. 130 131 Returns: 132 Cell, the backbone network. 133 134 Examples: 135 >>> from mindspore import nn 136 >>> # Define the network structure of LeNet5. Refer to 137 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 138 >>> net = LeNet5() 139 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) 140 >>> net_with_criterion = nn.WithLossCell(net, loss_fn) 141 >>> backbone = net_with_criterion.backbone_network 142 """ 143 return self._backbone 144 145 146class WithGradCell(Cell): 147 r""" 148 Cell that returns the gradients. 149 150 Wraps the network with backward cell to compute gradients. A network with a loss function is necessary 151 as argument. If loss function in None, the network must be a wrapper of network and loss function. This 152 Cell accepts '\*inputs' as inputs and returns gradients for each trainable parameter. 153 154 Note: 155 Run in PyNative mode. 156 157 Args: 158 network (Cell): The target network to wrap. The network only supports single output. 159 loss_fn (Cell): Primitive loss function used to compute gradients. Default: ``None`` . 160 sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitive for backpropagation, the type and shape 161 must be same as the `network` output. If ``None`` , we will fill one to a same type shape of 162 output value. Default: ``None`` . 163 164 Inputs: 165 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 166 167 Outputs: 168 list, a list of Tensors with identical shapes as trainable weights. 169 170 Raises: 171 TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple. 172 173 Supported Platforms: 174 ``Ascend`` ``GPU`` ``CPU`` 175 176 Examples: 177 >>> import mindspore as ms 178 >>> from mindspore import nn 179 >>> # Defined a network without loss function, taking LeNet5 as an example. 180 >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 181 >>> net = LeNet5() 182 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() 183 >>> grad_net = nn.WithGradCell(net, loss_fn) 184 >>> 185 >>> # For a network wrapped with loss function 186 >>> net = Net() 187 >>> net_with_criterion = nn.WithLossCell(net, loss_fn) 188 >>> grad_net = nn.WithGradCell(net_with_criterion) 189 """ 190 191 def __init__(self, network, loss_fn=None, sens=None): 192 super(WithGradCell, self).__init__(auto_prefix=False) 193 self.network = network 194 self.loss_fn = loss_fn 195 self.weights = ParameterTuple(network.trainable_params()) 196 self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None)) 197 self.sens = sens 198 if loss_fn is None: 199 self.network_with_loss = network 200 else: 201 self.network_with_loss = WithLossCell(self.network, self.loss_fn) 202 self.network_with_loss.set_train() 203 self._get_attr_from_cell(network) 204 205 def construct(self, *inputs): 206 weights = self.weights 207 if self.sens is None: 208 grads = self.grad(self.network_with_loss, weights)(*inputs) 209 else: 210 grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens) 211 return grads 212 213 214class ForwardValueAndGrad(Cell): 215 r""" 216 Encapsulate training network. 217 218 Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'. 219 The backward graph will be created in the gradient function to calculating gradient. 220 221 Args: 222 network (Union[Cell, Function, MethodType]): The training network. 223 weights (ParameterTuple): The parameters of the training network that need to calculate the gradient. 224 Default: ``None`` . 225 get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` . 226 get_by_list (bool): If ``True`` s, get all the gradients with respect to Parameter variables. 227 If get_all and get_by_list are both ``False`` , get the gradient with respect to first input. 228 If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and Parameter 229 variables at the same time in the form of ((gradients with respect to inputs), 230 (gradients with respect to parameters)). Default: ``False`` . 231 sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input. 232 If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically. 233 Default: ``False`` . 234 If the sens_param is ``True`` , a sensitivity (gradient with respect to output) needs to be transferred 235 through the input parameter. 236 237 Inputs: 238 - **\*inputs** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`. 239 - **sens** - A sensitivity (gradient with respect to output) as the input of backpropagation. 240 If network has single output, the sens is a tensor. 241 If network has multiple outputs, the sens is the tuple(tensor). 242 243 Outputs: 244 - **forward value** - The result of network forward running. 245 - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs. 246 247 Supported Platforms: 248 ``Ascend`` ``GPU`` ``CPU`` 249 250 Examples: 251 >>> import numpy as np 252 >>> import mindspore 253 >>> from mindspore import Tensor, nn, ops, ParameterTuple, Parameter 254 >>> 255 >>> class Net(nn.Cell): 256 ... def __init__(self): 257 ... super(Net, self).__init__() 258 ... self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight") 259 ... self.matmul = ops.MatMul() 260 ... 261 ... def construct(self, x): 262 ... out = self.matmul(x, self.weight) 263 ... return out 264 ... 265 >>> net = Net() 266 >>> criterion = nn.SoftmaxCrossEntropyWithLogits() 267 >>> net_with_criterion = nn.WithLossCell(net, criterion) 268 >>> weight = ParameterTuple(net.trainable_params()) 269 >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True) 270 >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32)) 271 >>> labels = Tensor(np.ones([1, 2]).astype(np.float32)) 272 >>> result = train_network(inputs, labels) 273 >>> print(result) 274 (Tensor(shape=[1], dtype=Float32, value= [ 1.38629436e+00]), ((Tensor(shape=[1, 2], dtype=Float32, value= 275 [[ -1.00000000e+00, -1.00000000e+00]]), Tensor(shape=[1, 2], dtype=Float32, value= 276 [[ 0.00000000e+00, 0.00000000e+00]])), (Tensor(shape=[2, 2], dtype=Float32, value= 277 [[ -5.00000000e-01, -5.00000000e-01], 278 [ -5.00000000e-01, -5.00000000e-01]]),))) 279 """ 280 281 def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False): 282 super(ForwardValueAndGrad, self).__init__(auto_prefix=False) 283 if not isinstance(network, (Cell, FunctionType, MethodType)): 284 raise TypeError(f"For 'ForwardValueAndGrad', " 285 f"the argument 'network' must be cell, function type or method type, " 286 f"but got '{type(network)}'") 287 if not isinstance(get_all, bool): 288 raise TypeError(f"For 'ForwardValueAndGrad', " 289 f"the type of 'get_all' must be bool, but got '{type(get_all)}'") 290 if not isinstance(get_by_list, bool): 291 raise TypeError(f"For 'ForwardValueAndGrad', " 292 f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'") 293 if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)): 294 raise TypeError(f"For 'ForwardValueAndGrad', " 295 f"when 'get_by_list' is set to True, the argument 'weights' must be " 296 f"Parameters array, but got '{type(weights)}'") 297 self.network = network 298 if isinstance(network, Cell): 299 self.network.set_grad() 300 self.weights = weights 301 self.get_all = get_all 302 self.get_by_list = get_by_list 303 self.sens_param = sens_param 304 self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param) 305 self._get_attr_from_cell(network) 306 307 def construct(self, *inputs): 308 grad_inputs = inputs 309 if self.sens_param: 310 inputs = inputs[:-1] 311 loss = self.network(*inputs) 312 if self.get_by_list: 313 grads = self.grad(self.network, self.weights)(*grad_inputs) 314 else: 315 grads = self.grad(self.network)(*grad_inputs) 316 return loss, grads 317 318 319class TrainOneStepCell(Cell): 320 r""" 321 Network training package class. 322 323 Wraps the `network` with the `optimizer`. The resulting Cell is trained with input '\*inputs'. 324 The backward graph will be created in the construct function to update the parameter. Different 325 parallel modes are available for training. 326 327 Args: 328 network (Cell): The training network. The network only supports single output. 329 optimizer (Union[Cell]): Optimizer for updating the network parameters. 330 sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default value is 331 ``None`` , which is ``1.0`` . 332 return_grad (bool): Whether to return gradient. If ``True``, it will return the gradient in the form of a dict 333 while returning loss. The key of the dict is the parameter name corresponding to the gradient, and value 334 is the gradient value. Default value is ``False`` . 335 336 Inputs: 337 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 338 339 Outputs: 340 Tensor, a tensor means the loss value, the shape of which is usually :math:`()`. 341 342 Raises: 343 TypeError: If `sens` is not a numbers.Number. 344 345 Supported Platforms: 346 ``Ascend`` ``GPU`` ``CPU`` 347 348 Examples: 349 >>> import mindspore.nn as nn 350 >>> # Define the network structure of LeNet5. Refer to 351 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 352 >>> net = LeNet5() 353 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() 354 >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 355 >>> #1) Using the WithLossCell provided by MindSpore 356 >>> loss_net = nn.WithLossCell(net, loss_fn) 357 >>> train_net = nn.TrainOneStepCell(loss_net, optim) 358 >>> 359 >>> #2) Using user-defined WithLossCell 360 >>> class MyWithLossCell(nn.Cell): 361 ... def __init__(self, backbone, loss_fn): 362 ... super(MyWithLossCell, self).__init__(auto_prefix=False) 363 ... self._backbone = backbone 364 ... self._loss_fn = loss_fn 365 ... 366 ... def construct(self, x, y, label): 367 ... out = self._backbone(x, y) 368 ... return self._loss_fn(out, label) 369 ... 370 ... @property 371 ... def backbone_network(self): 372 ... return self._backbone 373 ... 374 >>> loss_net = MyWithLossCell(net, loss_fn) 375 >>> train_net = nn.TrainOneStepCell(loss_net, optim) 376 """ 377 378 def __init__(self, network, optimizer, sens=None, return_grad=False): 379 super(TrainOneStepCell, self).__init__(auto_prefix=False) 380 self.network = network 381 self.network.set_grad() 382 self.optimizer = optimizer 383 self.weights = self.optimizer.parameters 384 self.grad = C.GradOperation(get_by_list=True, sens_param=True) 385 self.grad_no_sens = C.GradOperation(get_by_list=True) 386 self.sens = sens 387 if self.sens == 0: 388 raise ValueError("The input argument of 'sens' can not be 0.") 389 self.sense_flag = True 390 if self.sens is None: 391 self.sense_flag = False 392 self.sens = 1.0 393 self.return_grad = return_grad 394 if return_grad: 395 self.weights_name = [i.name for i in self.optimizer.parameters] 396 self.reducer_flag = False 397 self.grad_reducer = nn.Identity() 398 self.parallel_mode = _get_parallel_mode() 399 self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL) or \ 400 _is_pynative_parallel() 401 if self.reducer_flag: 402 self.mean = _get_gradients_mean() 403 self.degree = _get_device_num() 404 from mindspore.communication.management import GlobalComm 405 group = GlobalComm.WORLD_COMM_GROUP 406 if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)): 407 from mindspore.communication.management import get_group_size, create_group, get_rank 408 group_number = get_group_size() // 8 409 self.degree = int(self.degree / group_number) 410 group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)] 411 current_index = get_rank() // 8 412 server_group_name = "allreduce_" + str(current_index) 413 create_group(server_group_name, group_list[current_index]) 414 group = server_group_name 415 self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group) 416 self._get_attr_from_cell(network) 417 418 def construct(self, *inputs): 419 if not self.sense_flag: 420 return self._no_sens_impl(*inputs) 421 loss = self.network(*inputs) 422 sens = F.fill(loss.dtype, loss.shape, self.sens) 423 grads = self.grad(self.network, self.weights)(*inputs, sens) 424 grads = self.grad_reducer(grads) 425 loss = F.depend(loss, self.optimizer(grads)) 426 if self.return_grad: 427 grad_with_param_name = {} 428 for index, value in enumerate(grads): 429 grad_with_param_name[self.weights_name[index]] = value 430 return loss, grad_with_param_name 431 return loss 432 433 def _no_sens_impl(self, *inputs): 434 """construct implementation when the 'sens' parameter is passed in.""" 435 loss = self.network(*inputs) 436 grads = self.grad_no_sens(self.network, self.weights)(*inputs) 437 grads = self.grad_reducer(grads) 438 loss = F.depend(loss, self.optimizer(grads)) 439 if self.return_grad: 440 grad_with_param_name = {} 441 for index, value in enumerate(grads): 442 grad_with_param_name[self.weights_name[index]] = value 443 return loss, grad_with_param_name 444 return loss 445 446 447class GetNextSingleOp(Cell): 448 """ 449 Cell to run for getting the next operation. 450 451 For detailed information, refer to :class:`mindspore.ops.GetNext`. 452 453 Args: 454 dataset_types (list[:class:`mindspore.dtype`]): The types of dataset. 455 dataset_shapes (list[tuple[int]]): The shapes of dataset. 456 queue_name (str): Queue name to fetch the data. 457 458 Outputs: 459 tuple[Tensor], the data gets from Dataset. 460 461 Supported Platforms: 462 ``Ascend`` ``GPU`` 463 464 Examples: 465 >>> import mindspore 466 >>> from mindspore import ops, nn 467 >>> from mindspore import dataset as ds 468 >>> from mindspore import dtype as mstype 469 >>> 470 >>> data_path = "/path/to/MNIST_Data/train/" 471 >>> train_dataset = ds.MnistDataset(data_path, num_samples=10) 472 >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True) 473 >>> dataset = dataset_helper.iter.dataset 474 >>> dataset_types, dataset_shapes = dataset_helper.types_shapes() 475 >>> queue_name = dataset.__transfer_dataset__.queue_name 476 >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name) 477 >>> data, label = get_next_single_op_net() 478 >>> relu = ops.ReLU() 479 >>> result = relu(data.astype(mstype.float32)) 480 >>> print(result.shape) 481 (28, 28, 1) 482 """ 483 484 def __init__(self, dataset_types, dataset_shapes, queue_name): 485 super(GetNextSingleOp, self).__init__() 486 self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) 487 488 def construct(self): 489 return self.get_next() 490 491 492class _VirtualDatasetCell(Cell): 493 """ 494 Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. 495 496 _VirtualDataset is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs 497 of _VirtualDataset are distributed in data parallel pattern, tensor redistribution Primitives is inserted 498 dynamically during the graph compile process. 499 500 Note: 501 Only used in semi auto parallel and auto parallel mode. 502 503 Args: 504 backbone (Cell): The target network to wrap. 505 506 Examples: 507 >>> net = Net() 508 >>> net = _VirtualDatasetCell(net) 509 """ 510 511 def __init__(self, backbone): 512 super(_VirtualDatasetCell, self).__init__(auto_prefix=False) 513 self._backbone = backbone 514 self._virtual_dataset = _VirtualDataset() 515 self._get_attr_from_cell(backbone) 516 517 def construct(self, *inputs): 518 output = self._virtual_dataset(*inputs) 519 return self._backbone(*output) 520 521 522@_primexpr 523def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size): 524 if F.isconstant(input_shape[0]) is False: 525 return 526 if input_shape[0] % micro_size != 0: 527 raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be " 528 f"divided by micro size({micro_size})") 529 530 531class _MicroBatch(Cell): 532 """ 533 transform mini-batch to micro-batch in pipeline parallel. 534 535 Args: 536 params (micro_size): The number of micro-batch. 537 """ 538 def __init__(self, micro_size): 539 super(_MicroBatch, self).__init__() 540 self.shape = P.Shape() 541 self.micro_size = micro_size 542 self.strided_slice = P.StridedSlice() 543 544 def construct(self, i, *inputs): 545 """construct for _MicroBatch.""" 546 micro_inputs = () 547 for each_input in inputs: 548 input_shape = self.shape(each_input) 549 _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size) 550 micro_batch_begin = (input_shape[0] // self.micro_size) * i 551 micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1) 552 strided_slice_begin = (micro_batch_begin,) 553 strided_slice_strides = (1,) 554 for _ in range(len(input_shape) - 1): 555 strided_slice_begin += (0,) 556 strided_slice_strides += (1,) 557 strided_slice_end = (micro_batch_end,) 558 strided_slice_end += input_shape[1:] 559 micro_input = self.strided_slice(each_input, strided_slice_begin, strided_slice_end, strided_slice_strides) 560 micro_inputs += (micro_input,) 561 return micro_inputs 562 563 564class MicroBatchInterleaved(Cell): 565 """ 566 This function splits the input at the 0th into interleave_num pieces and then performs 567 the computation of the wrapped cell. Application scenario: When there is model parallelism in semi-automatic mode 568 and network, if the first slice data is calculating forward, the second slice data will execute the 569 communication operators at the same time, to achieve the performance acceleration of communication and computing 570 concurrency. 571 572 Note: 573 The output of the input network must be a single tensor. 574 575 Args: 576 network (Cell): The target network to wrap. 577 interleave_num (int, optional): split num of batch size. Default: ``2`` . 578 579 Inputs: 580 tuple[Tensor]. It's the same with the input of the `network` . 581 582 Outputs: 583 Tensor. The output of the input `network` . 584 585 Supported Platforms: 586 ``Ascend`` ``GPU`` 587 588 Examples: 589 >>> import mindspore.nn as nn 590 >>> # Define the network structure of LeNet5. Refer to 591 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 592 >>> net = LeNet5() 593 >>> net = nn.MicroBatchInterleaved(net, 2) 594 """ 595 def __init__(self, network, interleave_num=2): 596 super(MicroBatchInterleaved, self).__init__(auto_prefix=False) 597 if not isinstance(interleave_num, int): 598 raise TypeError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be integer, " 599 "but got the type : {}.".format(type(interleave_num))) 600 if interleave_num <= 0: 601 raise ValueError("For 'MicroBatchInterleaved', the argument 'interleave_num' must be large than 0, " 602 "but got {}.".format(interleave_num)) 603 self.network = network 604 self.interleave_num = interleave_num 605 self.interleave_inputs = nn.CellList() 606 self.add = P.Add().add_prim_attr("micro_interleaved_add_flag", True) 607 for _ in range(interleave_num): 608 interleave_data = _MicroBatch(interleave_num) 609 interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True) 610 interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num) 611 self.interleave_inputs.append(interleave_data) 612 self._get_attr_from_cell(network) 613 614 def construct(self, *inputs): 615 output = 0.0 616 for i in range(self.interleave_num): 617 interleave_input = self.interleave_inputs[i](i, *inputs) 618 output = self.add(output, self.network(*interleave_input)) 619 return output 620 621 622class PipelineCell(Cell): 623 """ 624 Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training. 625 626 Note: 627 micro_size must be greater or equal to pipeline stages. 628 629 Args: 630 network (Cell): The target network to wrap. 631 micro_size (int): MicroBatch size. 632 633 Supported Platforms: 634 ``Ascend`` ``GPU`` 635 636 Examples: 637 >>> import mindspore.nn as nn 638 >>> # Define the network structure of LeNet5. Refer to 639 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 640 >>> net = LeNet5() 641 >>> net = nn.PipelineCell(net, 4) 642 """ 643 def __init__(self, network, micro_size): 644 super(PipelineCell, self).__init__(auto_prefix=False) 645 self.network = network 646 self.micro_inputs = nn.CellList() 647 self.micro_size = micro_size 648 self.add_list = [] 649 if not isinstance(network, Cell): 650 raise TypeError("For 'PipelineCell', the argument 'network' must cell type, " 651 "but got the type : {}.".format(type(network))) 652 if not isinstance(micro_size, int): 653 raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, " 654 "but got the type : {}.".format(type(micro_size))) 655 if micro_size <= 0: 656 raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, " 657 "but got {}.".format(micro_size)) 658 for i in range(micro_size): 659 micro_input = _MicroBatch(micro_size) 660 self.micro_inputs.append(micro_input) 661 self.add = P.Add().add_prim_attr("pipeline_end", i) 662 self.add_list.append(self.add) 663 self._get_attr_from_cell(network) 664 665 def construct(self, *inputs): 666 ret = None 667 for i in range(self.micro_size): 668 micro_input = self.micro_inputs[i](i, *inputs) 669 output = self.network(*micro_input) 670 if ret is not None: 671 ret = self.add_list[i](ret, output) 672 else: 673 ret = output 674 return ret 675 676class GradAccumulationCell(Cell): 677 """ 678 Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode. 679 680 Args: 681 network (Cell): The target network to wrap. 682 micro_size (int): MicroBatch size. 683 684 Supported Platforms: 685 ``Ascend`` ``GPU`` 686 687 Examples: 688 >>> import mindspore.nn as nn 689 >>> # Define the network structure of LeNet5. Refer to 690 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 691 >>> net = LeNet5() 692 >>> net = nn.GradAccumulationCell(net, 4) 693 """ 694 def __init__(self, network, micro_size): 695 super(GradAccumulationCell, self).__init__(auto_prefix=False) 696 self.network = network 697 self.micro_inputs = nn.CellList() 698 self.micro_size = micro_size 699 self.add_list = [] 700 if not isinstance(network, Cell): 701 raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, " 702 "but got the type : {}.".format(type(network))) 703 if not isinstance(micro_size, int): 704 raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, " 705 "but got the type : {}.".format(type(micro_size))) 706 if micro_size <= 0: 707 raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be large than 0, " 708 "but got {}.".format(micro_size)) 709 for i in range(micro_size): 710 micro_input = _MicroBatch(micro_size) 711 micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size) 712 self.micro_inputs.append(micro_input) 713 self.add = P.Add().add_prim_attr("forward_end", i) 714 self.add_list.append(self.add) 715 self._get_attr_from_cell(network) 716 717 def construct(self, *inputs): 718 ret = None 719 for i in range(self.micro_size): 720 micro_input = self.micro_inputs[i](i, *inputs) 721 output = self.network(*micro_input) 722 if ret is not None: 723 ret = self.add_list[i](ret, output) 724 else: 725 ret = output 726 return ret 727 728 729def _pipeline_clear_grad(accu_grad, grad): 730 accu_grad = F.depend(accu_grad, grad) 731 zeros = F.zeros_like(accu_grad) 732 return F.assign(accu_grad, zeros) 733 734 735class _TrainGradAccuStepCell(TrainOneStepCell): 736 """ 737 Wraps the network with an optimizer in pipeline mode. 738 """ 739 def __init__(self, network, optimizer, sens=None): 740 super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens) 741 self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") 742 self.hyper_map = ops.HyperMap() 743 self.opt_shard = _get_enable_parallel_optimizer() 744 self._get_attr_from_cell(network) 745 self.enable_mindio = False 746 mode = get_context("mode") 747 device_type = get_context("device_target") 748 if device_type != "Ascend" or mode != GRAPH_MODE: 749 return 750 graceful_exit = os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT") 751 ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH") 752 ttp_path_check = ttp_lib_path is not None and os.path.isfile(ttp_lib_path) 753 if graceful_exit == "true" and ttp_path_check: 754 self.g_one = Tensor([0.1]) 755 self.allreduce_sum = ops.AllReduce() 756 self.enable_mindio = True 757 758 def construct(self, *inputs): 759 if not self.sense_flag: 760 return self._no_sens_impl(*inputs) 761 loss = self.network(*inputs) 762 sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens) 763 grads = self.grad(self.network, self.weights)(*inputs, sens) 764 accu_grads = ops.depend(self.accu_grads, grads) 765 if self.enable_mindio: 766 g_one = ops.depend(self.g_one, accu_grads) 767 g_one_res = self.allreduce_sum(g_one) 768 accu_grads = ops.depend(accu_grads, g_one_res) 769 grads = ops.depend(grads, g_one_res) 770 if self.opt_shard: 771 succ = self.optimizer(grads) 772 else: 773 succ = self.optimizer(accu_grads) 774 loss = ops.depend(loss, succ) 775 clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads) 776 loss = ops.depend(loss, clear) 777 return loss 778 779 def _no_sens_impl(self, *inputs): 780 """construct implementation when the 'sens' parameter is passed in.""" 781 loss = self.network(*inputs) 782 grads = self.grad_no_sens(self.network, self.weights)(*inputs) 783 accu_grads = ops.depend(self.accu_grads, grads) 784 if self.enable_mindio: 785 g_one = ops.depend(self.g_one, accu_grads) 786 g_one_res = self.allreduce_sum(g_one) 787 accu_grads = ops.depend(accu_grads, g_one_res) 788 grads = ops.depend(grads, g_one_res) 789 if self.opt_shard: 790 succ = self.optimizer(grads) 791 else: 792 succ = self.optimizer(accu_grads) 793 loss = ops.depend(loss, succ) 794 clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads) 795 loss = ops.depend(loss, clear) 796 return loss 797 798 799class AllreduceGraph(Cell): 800 """ 801 A allreduce graph to broadcast parameters. 802 """ 803 def __init__(self, inputs, group_name): 804 super(AllreduceGraph, self).__init__() 805 self.input_num = len(inputs) 806 self.inputs = inputs 807 self.allreduces = [] 808 self.assigns = [] 809 for _ in range(self.input_num): 810 self.allreduces.append(ops.AllReduce(op="sum", group=group_name)) 811 self.assigns.append(ops.Assign()) 812 813 def construct(self): 814 for i in range(self.input_num): 815 res = self.allreduces[i](self.inputs[i]) 816 self.assigns[i](self.inputs[i], res) 817 return self.inputs 818 819 820class VirtualDatasetCellTriple(Cell): 821 """ 822 Wrap the network with virtual dataset to convert data parallel layout to model parallel layout. 823 824 VirtualDatasetCellTriple is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs 825 of VirtualDatasetCellTriple are distributed in data parallel pattern, tensor redistribution Primitives is inserted 826 dynamically during the graph compile process. 827 828 Note: 829 Only used in semi auto parallel and auto parallel mode. There are three inputs, as contrary to two inputs in 830 _VirtualDatasetCell. 831 832 Args: 833 backbone (Cell): The target network to wrap. 834 835 Examples: 836 >>> import mindspore.nn as nn 837 >>> # Define the network structure of LeNet5. Refer to 838 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 839 >>> net = LeNet5() 840 >>> net = nn.VirtualDatasetCellTriple(net) 841 """ 842 843 def __init__(self, backbone): 844 super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False) 845 logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.") 846 self._backbone = backbone 847 self._get_attr_from_cell(backbone) 848 849 def construct(self, a, b, c): 850 return self._backbone(a, b, c) 851 852 853class WithEvalCell(Cell): 854 r""" 855 Wraps the forward network with the loss function. 856 857 It returns loss, forward output and label to calculate the metrics. 858 859 Args: 860 network (Cell): The forward network. 861 loss_fn (Cell): The loss function. 862 add_cast_fp32 (bool): Whether to adjust the data type to float32. Default: ``False`` . 863 864 Inputs: 865 - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. 866 - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. 867 868 Outputs: 869 Tuple(Tensor), containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)` 870 and a label Tensor of shape :math:`(N, \ldots)`. 871 872 Raises: 873 TypeError: If `add_cast_fp32` is not a bool. 874 875 Supported Platforms: 876 ``Ascend`` ``GPU`` ``CPU`` 877 878 Examples: 879 >>> import mindspore.nn as nn 880 >>> # Define a forward network without loss function, taking LeNet5 as an example. 881 >>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 882 >>> net = LeNet5() 883 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() 884 >>> eval_net = nn.WithEvalCell(net, loss_fn) 885 """ 886 887 def __init__(self, network, loss_fn, add_cast_fp32=False): 888 super(WithEvalCell, self).__init__(auto_prefix=False) 889 self._network = network 890 self._loss_fn = loss_fn 891 self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name) 892 self._get_attr_from_cell(network) 893 894 def construct(self, data, label): 895 outputs = self._network(data) 896 if self.add_cast_fp32: 897 label = F.mixed_precision_cast(mstype.float32, label) 898 outputs = F.cast(outputs, mstype.float32) 899 loss = self._loss_fn(outputs, label) 900 return loss, outputs, label 901 902 903class ParameterUpdate(Cell): 904 """ 905 Cell that updates parameter. 906 907 With this Cell, one can manually update `param` with the input `Tensor`. 908 909 Args: 910 param (Parameter): The parameter to be updated manually. 911 912 Inputs: 913 - **x** (Tensor) - A tensor whose shape and type are the same with `param`. 914 915 Outputs: 916 Tensor, the updated value. 917 918 Raises: 919 KeyError: If parameter with the specified name does not exist. 920 921 Supported Platforms: 922 ``Ascend`` ``GPU`` ``CPU`` 923 924 Examples: 925 >>> import numpy as np 926 >>> import mindspore 927 >>> from mindspore import nn, Tensor 928 >>> network = nn.Dense(3, 4) 929 >>> param = network.parameters_dict()['weight'] 930 >>> update = nn.ParameterUpdate(param) 931 >>> update.phase = "update_param" 932 >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32) 933 >>> output = update(weight) 934 >>> print(output) 935 [[ 0. 1. 2.] 936 [ 3. 4. 5.] 937 [ 6. 7. 8.] 938 [ 9. 10. 11.]] 939 """ 940 941 def __init__(self, param): 942 super(ParameterUpdate, self).__init__(auto_prefix=False) 943 if not isinstance(param, Parameter): 944 raise TypeError("For 'ParameterUpdate', 'param' must be 'Parameter', but got {}.".format(type(param))) 945 self._param = param 946 947 def construct(self, x): 948 F.assign(self._param, x) 949 return x 950 951 952class _BroadCastCell(Cell): 953 """ 954 Broadcast the parameters from device 0 to other devices. 955 956 Args: 957 params (list): The parameters of Net. 958 """ 959 960 def __init__(self, params): 961 super(_BroadCastCell, self).__init__() 962 from mindspore.communication.management import get_group_size, create_group 963 from mindspore import context 964 self.map_ = C.Map() 965 self.params = tuple(params) 966 if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE: 967 rank_list = [id for id in range(0, get_group_size())] 968 create_group("BroadcastWorldGroup", rank_list) 969 self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup") 970 else: 971 self.broadcast = P.Broadcast(0) 972 self.add_flags(skip_auto_parallel_compile=True) 973 974 def construct(self): 975 datatypes = self.map_(F.partial(_get_datatype), self.params) 976 params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params) 977 params = self.broadcast(params) 978 new_params = self.map_(F.partial(_cast_datatype), datatypes, params) 979 return new_params 980