# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel-strategy compile tests for the BroadcastTo operator."""
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(nn.Cell):
    """Broadcast a (1, 48, 64) weight to (8, 48, 64), then elementwise-multiply with x.

    Args:
        weight1: tensor broadcast by `BroadcastTo`; wrapped in a Parameter when
            `is_parameter` is True, otherwise used as a plain tensor.
        strategy1: shard strategy for `BroadcastTo` (single-input tuple).
        strategy2: shard strategy for `Mul` (two-input tuple).
        is_parameter: whether `weight1` becomes a trainable Parameter named "w1".
    """

    def __init__(self, weight1, strategy1=None, strategy2=None, is_parameter=True):
        super(Net, self).__init__()
        self.shape = (8, 48, 64)
        self.broadcast = P.BroadcastTo(self.shape).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        if is_parameter:
            self.weight1 = Parameter(weight1, "w1")
        else:
            self.weight1 = weight1

    def construct(self, x):
        out = self.broadcast(self.weight1)
        out = self.mul(x, out)
        return out


class MatMulNet(nn.Cell):
    """Broadcast x2 to (8, 64, 64), batch-matmul with x1, then multiply by a weight.

    Args:
        weight1: tensor for the final `Mul`; Parameter "w1" when `is_parameter`.
        strategy1: shard strategy for `BroadcastTo`.
        strategy2: shard strategy for `BatchMatMul`.
        strategy3: shard strategy for `Mul`.
        is_parameter: whether `weight1` becomes a trainable Parameter named "w1".
    """

    def __init__(self, weight1, strategy1=None, strategy2=None, strategy3=None, is_parameter=True):
        super(MatMulNet, self).__init__()
        self.shape = (8, 64, 64)
        self.broadcast = P.BroadcastTo(self.shape).shard(strategy1)
        self.matmul = P.BatchMatMul().shard(strategy2)
        self.mul = P.Mul().shard(strategy3)
        if is_parameter:
            self.weight1 = Parameter(weight1, "w1")
        else:
            self.weight1 = weight1

    def construct(self, x1, x2):
        out = self.broadcast(x2)
        out = self.matmul(x1, out)
        out = self.mul(out, self.weight1)
        return out


# Shared fixtures: weight broadcastable along dim 0, and the two network inputs.
_w1 = Tensor(np.ones([1, 48, 64]), dtype=ms.float32)
_x1 = Tensor(np.ones([8, 48, 64]), dtype=ms.float32)
_x2 = Tensor(np.ones([64, 64]), dtype=ms.float32)


def _compile(net, *inputs):
    """Wrap `net` in a training cell and graph-compile it with `inputs`.

    Resets the auto-parallel context afterwards so tests stay independent.
    """
    context.set_context(mode=context.GRAPH_MODE)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _cell_graph_executor.compile(train_net, *inputs)
    context.reset_auto_parallel_context()


def compile_net(net):
    """Compile a single-input network (Net) with the shared fixture `_x1`."""
    _compile(net, _x1)


def compile_net2(net):
    """Compile a two-input network (MatMulNet) with fixtures `_x1` and `_x2`."""
    _compile(net, _x1, _x2)


def test_BroadcastTo_parameter():
    """Semi-auto parallel: full 8-device sharding on both operators."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_BroadcastTo_parameter_no_full():
    """Semi-auto parallel: BroadcastTo sharded on only 4 of 8 devices (redistribution path)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)


def test_BroadcastTo_auto_parallel():
    """Auto parallel: strategies left to the cost model (no explicit shard)."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1)
    compile_net(net)


def test_BroadcastTo_matmul():
    """Semi-auto parallel: BroadcastTo feeding BatchMatMul then Mul."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 4),)
    strategy2 = ((1, 1, 2), (1, 2, 4))
    strategy3 = ((1, 2, 4), (1, 2, 4))
    net = MatMulNet(_w1, strategy1, strategy2, strategy3)
    compile_net2(net)