# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(nn.Cell):
    """Mul by a weight, split the product into three, then chain two MatMuls.

    Args:
        mul_weight: Tensor used to initialise the ``w1`` parameter.
        axis: Axis passed to ``P.Split``.
        out_nums: Number of split outputs. ``construct`` unpacks exactly three
            outputs, so callers must pass 3 (the default of 1 cannot work).
        strategy1: Shard strategy for ``Split``.
        strategy2: Shard strategy shared by ``Mul`` and the first ``MatMul``.
        strategy3: Shard strategy for the second ``MatMul``.
    """

    def __init__(self, mul_weight, axis=0, out_nums=1, strategy1=None, strategy2=None, strategy3=None):
        super(Net, self).__init__()
        self.split = P.Split(axis, out_nums).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        self.matmul = P.MatMul(transpose_b=True).shard(strategy2)
        self.matmul2 = P.MatMul().shard(strategy3)
        self.weight = Parameter(mul_weight, "w1")

    def construct(self, x):
        out = self.mul(x, self.weight)
        out1, out2, out3 = self.split(out)
        # out1 @ out2^T (transpose_b=True), then the result @ out3.
        out = self.matmul(out1, out2)
        out = self.matmul2(out, out3)
        return out


class Net1(nn.Cell):
    """Split the weight parameter itself and multiply the input by each piece.

    Args:
        mul_weight: Tensor used to initialise the ``w1`` parameter.
        axis: Axis passed to ``P.Split``.
        out_nums: Number of split outputs. ``construct`` unpacks exactly two
            outputs, so callers must pass 2.
        strategy1: Shard strategy for ``Split`` (applied to the parameter).
        strategy2: Shard strategy for ``Mul`` (used for both multiplications).
    """

    def __init__(self, mul_weight, axis=0, out_nums=1, strategy1=None, strategy2=None):
        super(Net1, self).__init__()
        self.split = P.Split(axis, out_nums).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        self.weight = Parameter(mul_weight, "w1")

    def construct(self, x):
        # The Split input is a Parameter rather than an intermediate tensor.
        out1, out2 = self.split(self.weight)
        out = self.mul(x, out1)
        out = self.mul(out, out2)
        return out


class Net2(nn.Cell):
    """Mul by a weight, split the product, and return only the first piece.

    Args:
        mul_weight: Tensor used to initialise the ``w1`` parameter.
        axis: Axis passed to ``P.Split``.
        out_nums: Number of split outputs. ``construct`` unpacks exactly two
            outputs, so callers must pass 2.
        strategy1: Shard strategy for ``Split``.
        strategy2: Shard strategy for ``Mul``.
    """

    def __init__(self, mul_weight, axis=0, out_nums=1, strategy1=None, strategy2=None):
        super(Net2, self).__init__()
        self.split = P.Split(axis, out_nums).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        self.weight = Parameter(mul_weight, "w1")

    def construct(self, x):
        out = self.mul(x, self.weight)
        # The second split output is deliberately unused.
        out1, _ = self.split(out)
        return out1


# 2-D weight/input pair used by Net (compiled via compile_net).
_w = Tensor(np.ones([48, 64]), dtype=ms.float32)
_x = Tensor(np.ones([48, 64]), dtype=ms.float32)

# Net1: splitting _w1 along axis 0 into two yields [48, 64, 32] halves matching _x1.
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32)

# Net2: weight with the same shape as _x1.
_w2 = Tensor(np.ones([48, 64, 32]), dtype=ms.float32)


def _compile(net, inputs):
    """Wrap *net* in a one-step Momentum training cell and compile it on *inputs*.

    The auto-parallel context is reset in a ``finally`` clause, so a case that
    fails during graph construction or compilation does not leak its parallel
    configuration into the tests that run after it.
    """
    context.set_context(mode=context.GRAPH_MODE)
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, inputs)
    finally:
        context.reset_auto_parallel_context()


def compile_net(net):
    """Compile *net* for training with the 2-D input ``_x``."""
    _compile(net, _x)


def compile_net1(net):
    """Compile *net* for training with the 3-D input ``_x1``."""
    _compile(net, _x1)


def test_split_parameter():
    """Split a Parameter with a strategy that uses all 8 devices (1*4*2)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net1(_w1, 0, 2, strategy1, strategy2)
    compile_net1(net)


def test_split_parameter_no_full_split():
    """Split a Parameter with a strategy covering only 4 of the 8 devices (1*2*2)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net1(_w1, 0, 2, strategy1, strategy2)
    compile_net1(net)


def test_split_tensor():
    """Split an intermediate tensor into three and feed the pieces to two MatMuls."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8),)
    strategy2 = ((1, 8), (1, 8))
    strategy3 = ((1, 1), (1, 8))
    net = Net(_w, 0, 3, strategy1, strategy2, strategy3)
    compile_net(net)


def test_split_output():
    """Return one Split output directly, with a strategy using all 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w2, 0, 2, strategy1, strategy2)
    compile_net1(net)


def test_split_output_no_full_split():
    """Return one Split output directly, with a strategy covering 4 of 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 2, 2),)
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w2, 0, 2, strategy1, strategy2)
    compile_net1(net)


def test_split_no_strategy():
    """Compile with ``Split.shard(None)``; only Mul gets an explicit strategy."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = None
    strategy2 = ((1, 4, 2), (1, 4, 2))
    net = Net2(_w2, 0, 2, strategy1, strategy2)
    compile_net1(net)


def test_split_auto_parallel():
    """Compile Net2 in auto_parallel mode without any explicit strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2, 0, 2)
    compile_net1(net)