# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType

from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode, _get_enable_parallel_optimizer)
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire the parameter datatype.

    Args:
        param (Tensor): The parameter before the operation.

    Returns:
        mstype, the datatype of the parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the given datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before the operation.

    Returns:
        Tensor, the parameter after the operation.
    """
    return F.cast(param, datatype)


class WithLossCell(Cell):
    r"""
    Cell with loss function.

    Wraps the network with a loss function. This Cell accepts data and label as inputs and
    returns the computed loss.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor representing the loss value, the shape of which is usually :math:`()`.

    Raises:
        TypeError: If dtype of `data` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 10]).astype(np.float32))
        >>>
        >>> output_data = net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, the backbone network.
        """
        return self._backbone


class WithGradCell(Cell):
    r"""
    Cell that returns the gradients.

    Wraps the network with a backward cell to compute gradients. A network combined with a loss function is
    required as the argument. If `loss_fn` is None, the network must already be a wrapper of the network and
    the loss function. This Cell accepts '\*inputs' as inputs and returns gradients for each trainable parameter.

    Note:
        Run in PyNative mode.

    Args:
        network (Cell): The target network to wrap. The network only supports single output.
        loss_fn (Cell): Primitive loss function used to compute gradients. Default: None.
        sens (Union[None, Tensor, Scalar, Tuple ...]): The sensitivity for backpropagation; the type and shape
            must be the same as the `network` output. If None, a ones tensor with the same type and shape as
            the output value will be used. Default: None.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        list, a list of Tensors with the same shapes as the trainable weights.

    Raises:
        TypeError: If `sens` is not one of None, Tensor, Scalar or Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> grad_net = nn.WithGradCell(net, loss_fn)
        >>>
        >>> # For a network wrapped with loss function
        >>> net = Net()
        >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
        >>> grad_net = nn.WithGradCell(net_with_criterion)
    """

    def __init__(self, network, loss_fn=None, sens=None):
        super(WithGradCell, self).__init__(auto_prefix=False)
        self.network = network
        self.loss_fn = loss_fn
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
        self.sens = sens
        if loss_fn is None:
            self.network_with_loss = network
        else:
            self.network_with_loss = WithLossCell(self.network, self.loss_fn)
        self.network_with_loss.set_train()

    def construct(self, *inputs):
        weights = self.weights
        if self.sens is None:
            grads = self.grad(self.network_with_loss, weights)(*inputs)
        else:
            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
        return grads
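

# Illustrative sketch (not part of the original module): a minimal helper showing how
# WithGradCell is typically driven in PyNative mode. `backbone`, `loss_fn`, `data` and
# `label` are placeholders supplied by the caller; nothing here runs at import time.
def _example_with_grad_cell(backbone, loss_fn, data, label):
    """Illustrative only: return one gradient per trainable parameter of `backbone`."""
    grad_net = WithGradCell(backbone, loss_fn)
    # Calling the wrapper runs backbone + loss_fn forward, then differentiates the loss
    # with respect to backbone.trainable_params().
    return grad_net(data, label)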


class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculate the gradients.

    Args:
        network (Cell): The training network.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to the first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter
            variables at the same time, in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append the sensitivity (gradient with respect to output) as an input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            If sens_param is True, the sensitivity (gradient with respect to output) needs to be transferred
            through an input parameter. Default: False.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor...)) - Tuple of inputs with shape :math:`(N, \ldots)`.
        - **(sens)** - The sensitivity (gradient with respect to output) used as the input of backpropagation.
          If the network has a single output, `sens` is a tensor.
          If the network has multiple outputs, `sens` is a tuple of tensors.

    Outputs:
        - **forward value** - The result of the network forward running.
        - **gradients** (tuple(tensor)) - The gradients of the network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="weight")
        ...         self.matmul = P.MatMul()
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(x, self.weight)
        ...         return out
        ...
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_with_criterion = nn.WithLossCell(net, criterion)
        >>> weight = ParameterTuple(net.trainable_params())
        >>> train_network = nn.ForwardValueAndGrad(net_with_criterion, weights=weight, get_all=True, get_by_list=True)
        >>> inputs = Tensor(np.ones([1, 2]).astype(np.float32))
        >>> labels = Tensor(np.zeros([1, 2]).astype(np.float32))
        >>> result = train_network(inputs, labels)
        >>> print(result)
        (Tensor(shape=[1], dtype=Float32, value=[0.00000000e+00]), ((Tensor(shape=[1, 2], dtype=Float32, value=
        [[1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[1, 2], dtype=Float32, value=
        [[0.00000000e+00, 0.00000000e+00]])), (Tensor(shape=[2, 2], dtype=Float32, value=
        [[5.00000000e-01, 5.00000000e-01],
         [5.00000000e-01, 5.00000000e-01]]),)))
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"The type of training network should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if not isinstance(get_all, bool):
            raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        self.network = network
        if isinstance(network, Cell):
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        grad_inputs = inputs
        if self.sens_param:
            inputs = inputs[:-1]
        loss = self.network(*inputs)
        if self.get_by_list:
            grads = self.grad(self.network, self.weights)(*grad_inputs)
        else:
            grads = self.grad(self.network)(*grad_inputs)
        return loss, grads
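

# Illustrative sketch (not part of the original module): with sens_param=True,
# ForwardValueAndGrad expects the sensitivity (gradient with respect to the network output)
# to be appended as the last positional input; construct() strips it from the forward call
# and forwards it to the backward pass. `network`, `weights` (a ParameterTuple when
# get_by_list=True), `data` and `sens` are placeholders supplied by the caller.
def _example_forward_value_and_grad_with_sens(network, weights, data, sens):
    """Illustrative only: forward value plus parameter gradients scaled by `sens`."""
    value_and_grad = ForwardValueAndGrad(network, weights=weights,
                                         get_by_list=True, sens_param=True)
    # `sens` must match network(data) in shape and dtype.
    value, grads = value_and_grad(data, sens)
    return value, grads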


class TrainOneStepCell(Cell):
    r"""
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the construct function to update the parameters. Different
    parallel modes are available for training.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        sens (numbers.Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a tensor representing the loss value, the shape of which is usually :math:`()`.

    Raises:
        TypeError: If `sens` is not a number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> # 1) Using the existing WithLossCell
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...         self._backbone = backbone
        ...         self._loss_fn = loss_fn
        ...
        ...     def construct(self, x, y, label):
        ...         out = self._backbone(x, y)
        ...         return self._loss_fn(out, label)
        ...
        ...     @property
        ...     def backbone_network(self):
        ...         return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        loss = F.depend(loss, self.optimizer(grads))
        return loss
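

# Illustrative sketch (not part of the original module): a bare-bones training loop built
# from WithLossCell and TrainOneStepCell. `backbone`, `loss_fn`, `optimizer` and `dataset`
# (an iterable of (data, label) pairs) are placeholders supplied by the caller.
def _example_train_one_step_loop(backbone, loss_fn, optimizer, dataset, epochs=1):
    """Illustrative only: run `epochs` passes of forward/backward/update steps."""
    train_net = TrainOneStepCell(WithLossCell(backbone, loss_fn), optimizer)
    train_net.set_train()
    loss = None
    for _ in range(epochs):
        for data, label in dataset:
            loss = train_net(data, label)  # one step: forward, backward, optimizer update
    return loss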


class GetNextSingleOp(Cell):
    """
    Cell that runs the GetNext op to fetch the next data from the dataset queue.

    Args:
        dataset_types (list[:class:`mindspore.dtype`]): The types of the dataset.
        dataset_shapes (list[tuple[int]]): The shapes of the dataset.
        queue_name (str): Queue name to fetch the data.

    For detailed information, refer to `ops.operations.GetNext`.

    Inputs:
        No inputs.

    Outputs:
        tuple[Tensor], the data fetched from the dataset.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> train_dataset = create_custom_dataset()
        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
        >>> dataset = dataset_helper.iter.dataset
        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
        >>> queue_name = dataset.__transfer_dataset__.queue_name
        >>> get_next_single_op_net = nn.GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
        >>> data, label = get_next_single_op_net()
        >>> relu = P.ReLU()
        >>> result = relu(data).asnumpy()
        >>> print(result.shape)
        (32, 1, 32, 32)
    """

    def __init__(self, dataset_types, dataset_shapes, queue_name):
        super(GetNextSingleOp, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)

    def construct(self):
        return self.get_next()


class _VirtualDatasetCell(Cell):
    """
    Wrap the network with virtual dataset to convert the data parallel layout to the model parallel layout.

    _VirtualDataset is a virtual Primitive; it does not exist in the final executing graph. Inputs and outputs
    of _VirtualDataset are distributed in data parallel pattern, and tensor redistribution Primitives are
    inserted dynamically during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = _VirtualDatasetCell(net)
    """

    def __init__(self, backbone):
        super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, *inputs):
        output = self._virtual_dataset(*inputs)
        return self._backbone(*output)


class _MicroBatch(Cell):
    """
    Transform a mini-batch into micro-batches for pipeline parallelism.

    Args:
        micro_size (int): The number of micro-batches.
    """
    def __init__(self, micro_size):
        super(_MicroBatch, self).__init__()
        self.shape = P.Shape()
        self.micro_size = micro_size

    def construct(self, i, *inputs):
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            micro_batch_begin = i * input_shape[0] // self.micro_size
            micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
            micro_input = each_input[micro_batch_begin:micro_batch_end]
            micro_inputs += (micro_input,)
        return micro_inputs
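

# Illustrative sketch (not part of the original module): the slicing arithmetic used by
# _MicroBatch above. For a mini-batch of size B and micro_size M, micro-batch i covers
# rows [i * B // M, (i + 1) * B // M) along the batch dimension, e.g. B=8, M=4 gives
# slices [0:2], [2:4], [4:6], [6:8]. This plain-Python helper mirrors that logic for
# documentation purposes only.
def _example_micro_batch_bounds(batch_size, micro_size):
    """Illustrative only: return the (begin, end) row bounds of every micro-batch."""
    return [(i * batch_size // micro_size, (i + 1) * batch_size // micro_size)
            for i in range(micro_size)]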


class PipelineCell(Cell):
    """
    Wrap the network so that it runs with micro-batches.

    Note:
        micro_size must be greater than or equal to the number of pipeline stages.

    Args:
        network (Cell): The target network to wrap.
        micro_size (int): The number of micro-batches.

    Examples:
        >>> net = Net()
        >>> net = PipelineCell(net, 4)
    """
    def __init__(self, network, micro_size):
        super(PipelineCell, self).__init__(auto_prefix=False)
        self.network = network
        self.micro_inputs = nn.CellList()
        self.micro_size = micro_size
        self.add_list = []
        for i in range(micro_size):
            micro_input = _MicroBatch(micro_size)
            self.micro_inputs.append(micro_input)
            self.add = P.Add().add_prim_attr("pipeline_end", i)
            self.add_list.append(self.add)

    def construct(self, *inputs):
        ret = None
        for i in range(self.micro_size):
            micro_input = self.micro_inputs[i](i, *inputs)
            output = self.network(*micro_input)
            if ret is not None:
                ret = self.add_list[i](ret, output)
            else:
                ret = output
        return ret


def _pipeline_clear_grad(accu_grad, grad):
    accu_grad = F.depend(accu_grad, grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    return F.assign(accu_grad, zeros)


class _TrainPipelineAccuStepCell(TrainOneStepCell):
    """
    Wraps the network with an optimizer in pipeline mode.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.hyper_map = ops.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        accu_grads = ops.depend(self.accu_grads, grads)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
            succ = self.optimizer(accu_grads)
        loss = ops.depend(loss, succ)
        clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads)
        loss = ops.depend(loss, clear)
        return loss


class VirtualDatasetCellTriple(Cell):
    """
    Wrap the network with virtual dataset to convert the data parallel layout to the model parallel layout.

    VirtualDatasetCellTriple is a virtual Primitive; it does not exist in the final executing graph. Inputs and
    outputs of VirtualDatasetCellTriple are distributed in data parallel pattern, and tensor redistribution
    Primitives are inserted dynamically during the graph compile process.

    Note:
        Only used in semi auto parallel and auto parallel mode. There are exactly three inputs, as opposed to
        the variable number of inputs accepted by _VirtualDatasetCell.

    Args:
        backbone (Cell): The target network to wrap.

    Examples:
        >>> net = Net()
        >>> net = VirtualDatasetCellTriple(net)
    """

    def __init__(self, backbone):
        super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, a, b, c):
        a_, b_, c_ = self._virtual_dataset(a, b, c)
        return self._backbone(a_, b_, c_)


class WithEvalCell(Cell):
    r"""
    Cell that returns loss, output and label for evaluation.

    This Cell accepts a network and a loss function as arguments and computes the loss for the model.
    It returns loss, output and label for calculating the metrics.

    Args:
        network (Cell): The network Cell.
        loss_fn (Cell): The loss Cell.
        add_cast_fp32 (bool): Whether to cast the network output and label to float32. Default: False.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar loss Tensor, a network output Tensor of shape :math:`(N, \ldots)`
        and a label Tensor of shape :math:`(N, \ldots)`.

    Raises:
        TypeError: If `add_cast_fp32` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> eval_net = nn.WithEvalCell(net, loss_fn)
    """

    def __init__(self, network, loss_fn, add_cast_fp32=False):
        super(WithEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn
        self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
            label = F.mixed_precision_cast(mstype.float32, label)
            outputs = F.cast(outputs, mstype.float32)
        loss = self._loss_fn(outputs, label)
        return loss, outputs, label
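

# Illustrative sketch (not part of the original module): a minimal evaluation pass using
# WithEvalCell. `backbone`, `loss_fn`, `dataset` (an iterable of (data, label) pairs) and
# `metric` (any object exposing clear/update/eval, such as a mindspore.nn metric) are
# placeholders supplied by the caller.
def _example_eval_loop(backbone, loss_fn, dataset, metric):
    """Illustrative only: feed (output, label) pairs from WithEvalCell into a metric."""
    eval_net = WithEvalCell(backbone, loss_fn)
    eval_net.set_train(False)
    metric.clear()
    for data, label in dataset:
        _, outputs, label = eval_net(data, label)
        metric.update(outputs, label)
    return metric.eval()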


class ParameterUpdate(Cell):
    """
    Cell that updates a parameter.

    With this Cell, one can manually update `param` with the input `Tensor`.

    Args:
        param (Parameter): The parameter to be updated manually.

    Inputs:
        - **x** (Tensor) - A tensor whose shape and type are the same as `param`.

    Outputs:
        Tensor, the input `x`.

    Raises:
        KeyError: If the parameter with the specified name does not exist.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> network = nn.Dense(3, 4)
        >>> param = network.parameters_dict()['weight']
        >>> update = nn.ParameterUpdate(param)
        >>> update.phase = "update_param"
        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
        >>> output = update(weight)
    """

    def __init__(self, param):
        super(ParameterUpdate, self).__init__(auto_prefix=False)
        if not isinstance(param, Parameter):
            raise TypeError("`param` must be `Parameter`, but got {}".format(param))
        self._param = param

    def construct(self, x):
        F.assign(self._param, x)
        return x


class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to the other devices.

    Args:
        params (list): The parameters of the network.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        from mindspore.communication.management import get_group_size, create_group
        from mindspore import context
        self.map_ = C.Map()
        self.params = tuple(params)
        if context.get_context("device_target") == "Ascend" and context.get_context("mode") != context.PYNATIVE_MODE:
            rank_list = list(range(get_group_size()))
            create_group("BroadcastWorldGroup", rank_list)
            self.broadcast = P.Broadcast(0, group="BroadcastWorldGroup")
        else:
            self.broadcast = P.Broadcast(0)

    def construct(self):
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params