# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore.common.api import _cell_graph_executor
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class TwoInputBpropOperator(Cell):
    """Two-input cell with a hand-written backward pass.

    The forward pass applies ``P.Mul`` to its two inputs. The ``bprop``
    method does not differentiate the multiplication; it returns two values
    built with ``P.Add`` instead, so the test exercises a user-defined
    backward graph.
    """

    def __init__(self):
        super().__init__()
        # `op` is the forward primitive, `bp` the primitive used in bprop.
        # Both attribute names are accessed from outside to set strategies.
        self.op = P.Mul()
        self.bp = P.Add()

    def construct(self, x, y):
        """Return the element-wise product of ``x`` and ``y``."""
        product = self.op(x, y)
        return product

    def bprop(self, x, y, out, dout):
        """Return ``(5 + x, y + 8)``; ``out`` and ``dout`` are not used."""
        first = self.bp(5, x)
        second = self.bp(y, 8)
        return first, second


class ParallelFloorDivBpropNet(Cell):
    """Network chaining the custom-bprop Mul, FloorDiv and BatchNorm1d.

    Args:
        mul_size: Shape of the ``mul_weight`` parameter (filled with 0.5).
        test_size: Shape of the ``floordiv_weight`` parameter (filled
            with 0.1).
        strategy: Shard strategy applied to ``FloorDiv``. When it is None,
            no strategy is set on any operator, including the Mul cell.
        strategy2: Shard strategy applied to both primitives of the
            ``TwoInputBpropOperator`` (forward ``op`` and backward ``bp``).
            Only used when ``strategy`` is not None.
    """

    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
        super().__init__()
        self.mul_weight = Parameter(
            Tensor(np.full(mul_size, 0.5, dtype=np.float32)),
            name="mul_weight",
        )
        self.floordiv_weight = Parameter(
            Tensor(np.full(test_size, 0.1, dtype=np.float32)),
            name="floordiv_weight",
        )
        self.mul = TwoInputBpropOperator()
        self.floor_div = P.FloorDiv()
        # The feature count is hard-coded to 96.
        self.bn = nn.BatchNorm1d(num_features=96)

        if strategy is None:
            return

        # Strategies are only configured when `strategy` is provided; the
        # Mul cell's primitives take `strategy2`, FloorDiv takes `strategy`.
        self.mul.op.shard(strategy2)
        self.mul.bp.shard(strategy2)
        self.floor_div.shard(strategy)

    def construct(self, inputs, label):
        """Run Mul -> FloorDiv -> BatchNorm1d on ``inputs``.

        ``label`` is accepted but not used by the computation.
        """
        scaled = self.mul(inputs, self.mul_weight)
        divided = self.floor_div(scaled, self.floordiv_weight)
        return self.bn(divided)


# Random float32 tensors of shape (128, 96) fed to the network at compile time.
inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a one-step training cell and compile its graph.

    A Momentum optimizer (learning rate 0.1, momentum 0.9) is built over the
    network's trainable parameters, the resulting ``TrainOneStepCell`` is
    switched to auto-parallel and training mode, and the graph is compiled
    with the module-level ``inputs_`` and ``label_`` tensors.

    The auto-parallel context is reset afterwards even when optimizer
    construction or compilation raises, so a failure here does not leave the
    parallel configuration active for code that runs later in the process.
    Any exception is propagated unchanged.

    Args:
        net: The network cell to compile.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, inputs_, label_)
    finally:
        context.reset_auto_parallel_context()


def test_net():
    """Compile ParallelFloorDivBpropNet under semi-auto-parallel on 4 devices.

    The same ``((4, 1), (4, 1))`` strategy is passed for both the custom-bprop
    Mul cell and FloorDiv. The test passes if compilation completes without
    raising.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy = ((4, 1), (4, 1))
    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
    compile_net(net)