# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(nn.Cell):
    """ReLU -> Transpose -> Reshape -> BatchMatMul x2 -> Reshape -> Transpose -> ReLU.

    Args:
        wi: Initial value wrapped into the trainable parameter "wi"
            (right operand of the first BatchMatMul).
        wo: Initial value wrapped into the trainable parameter "wo"
            (right operand of the second BatchMatMul).
        stra1: Shard strategy passed to the first ReLU.
        stra2: Shard strategy passed to the first Transpose.
        stra3: Shard strategy passed to the first BatchMatMul.
        stra4: Shard strategy passed to the second BatchMatMul.
        stra5: Shard strategy passed to the second Transpose.
        stra6: Shard strategy passed to the second ReLU.
    """

    def __init__(self, wi, wo, stra1=None, stra2=None, stra3=None, stra4=None,
                 stra5=None, stra6=None):
        super().__init__()
        # Operators are created in the order in which construct() uses them;
        # each one receives its own shard strategy.
        self.relu = P.ReLU().shard(stra1)
        self.transpose = P.Transpose().shard(stra2)
        self.wi = Parameter(wi, "wi")
        self.batch_mm = P.BatchMatMul().shard(stra3)
        self.wo = Parameter(wo, "wo")
        self.batch_mm2 = P.BatchMatMul().shard(stra4)
        self.transpose2 = P.Transpose().shard(stra5)
        self.relu2 = P.ReLU().shard(stra6)
        # The two Reshape ops carry no explicit shard strategy.
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()

    def construct(self, x):
        """Run the forward pass on a 4-D input and return a 4-D result."""
        activated = self.relu(x)
        trans_out = self.transpose(activated, (2, 0, 3, 1))

        # Fold axes 1 and 2 of the transposed tensor into one axis so the
        # data becomes 3-D for the batched matmuls.
        merged_shape = (trans_out.shape[0],
                        trans_out.shape[1] * trans_out.shape[2],
                        trans_out.shape[3])
        flattened = self.reshape(trans_out, merged_shape)

        hidden = self.batch_mm(flattened, self.wi)
        projected = self.batch_mm2(hidden, self.wo)

        # Restore the 4-D layout, then apply the inverse of the first
        # permutation ((1, 3, 0, 2) undoes (2, 0, 3, 1)).
        restored = self.reshape2(projected, trans_out.shape)
        permuted_back = self.transpose2(restored, (1, 3, 0, 2))
        return self.relu2(permuted_back)


# Network input shared by every test in this file.
_x = Tensor(np.ones([32, 16, 48, 128]), dtype=ms.float32)
# Initial values for the "wi" and "wo" parameters of Net.
_wi = Tensor(np.ones([48, 16, 64]), dtype=ms.float32)
_wo = Tensor(np.ones([48, 64, 16]), dtype=ms.float32)


def compile_net(net):
    """Compile one training step of ``net`` in graph mode with input ``_x``.

    The net is wrapped with a Momentum optimizer inside a TrainOneStepCell,
    marked for auto-parallel and compiled. The auto-parallel context is reset
    afterwards, even when compilation raises, so that settings made by one
    test cannot leak into the next.

    Args:
        net: The network cell to compile.
    """
    try:
        context.set_context(mode=context.GRAPH_MODE)
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x)
    finally:
        # Always restore the global parallel configuration; the exception
        # (if any) still propagates to the caller.
        context.reset_auto_parallel_context()


def test_batchmm():
    """Compile Net on 8 devices with every operator split 8 ways on axis 0."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
                                      global_rank=0)
    stra1 = ((8, 1, 1, 1),)
    stra2 = ((8, 1, 1, 1),)
    stra3 = ((8, 1, 1), (8, 1, 1))
    stra4 = ((8, 1, 1), (8, 1, 1))
    stra5 = ((8, 1, 1, 1),)
    stra6 = ((8, 1, 1, 1),)
    net = Net(_wi, _wo, stra1=stra1, stra2=stra2, stra3=stra3, stra4=stra4, stra5=stra5, stra6=stra6)
    compile_net(net)


def test_batchmm2():
    """Compile Net on 32 devices with the BatchMatMuls additionally split 8 ways.

    Element-wise and transpose operators are split 4 ways on axis 0. The first
    BatchMatMul splits the last axis of its second operand by 8; the second
    BatchMatMul splits the last axis of its first operand and the middle axis
    of its second operand by 8.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", enable_alltoall=True,
                                      device_num=32, global_rank=0)
    stra1 = ((4, 1, 1, 1),)
    stra2 = ((4, 1, 1, 1),)
    stra3 = ((4, 1, 1), (4, 1, 8))
    stra4 = ((4, 1, 8), (4, 8, 1))
    stra5 = ((4, 1, 1, 1),)
    stra6 = ((4, 1, 1, 1),)
    net = Net(_wi, _wo, stra1=stra1, stra2=stra2, stra3=stra3, stra4=stra4, stra5=stra5, stra6=stra6)
    compile_net(net)