1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16import numpy as np 17 18import mindspore as ms 19import mindspore.nn as nn 20from mindspore import Tensor 21from mindspore import context 22from mindspore.common.api import _cell_graph_executor 23from mindspore.ops import composite as C 24from mindspore.ops import operations as P 25from mindspore.parallel import set_algo_parameters 26from mindspore.ops.operations._inner_ops import DSDMatmul 27from tests.ut.python.ops.test_math_ops import VirtualLoss 28 29context.set_context(mode=context.GRAPH_MODE) 30 31grad_all = C.GradOperation(get_all=True) 32 33 34# input_w1, the shape is (batch_size, head, block_num, head_size // 16, block_size//16, 16, 16) 35# input_w1 cum_shape = batch_size * seq_len * embedding_size * (block_size // size_per_head) 36# = batch_size * seq_len * (embedding_size // 2) 37# input_w2, the shape is (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16) 38# input_w2 cum_shape = batch_size * seq_len * embedding_size * (global_size // size_per_head) 39# = batch_size * seq_len * embedding_size * 2 40# input_v, the shape is (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16) 41# block_num = seq_len // block_size, block_size = 64, head * v_embedding = embedding_size, always. 
# output shape is (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)

# Problem sizes shared by Net and compile_graph: the input tensor built in
# compile_graph must match the reshapes performed inside Net.
SEQ_LEN = 1024
V_EMBEDDING = 128

# Sizes used by every test case below.
BATCH_SIZE = 128
NUM_HEADS = 32
DEVICE_NUM = 16


class Net(nn.Cell):
    """Dense q/k/v projections feeding DSDMatmul, then a transpose/reshape and ReduceSum.

    Args:
        batch_size: number of sequences in the input.
        num_heads: number of heads; embedding_size = num_heads * v_embedding.
        dp: data-parallel split factor used in the explicit shard strategies.
        mp: model-parallel split factor used in the explicit shard strategies.
        shard: when True, attach explicit shard strategies to the operators;
            when False, attach none.
    """

    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_len = SEQ_LEN
        self.block_size = 64
        self.head_size = self.block_size
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.v_embedding = V_EMBEDDING
        self.embedding_size = num_heads * self.v_embedding
        self.dsd_matmul = DSDMatmul()
        self.reduce_sum = P.ReduceSum()
        # dense1 / dense2 / dense3 produce the q / k / v operands respectively.
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size // 2, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size * 2, has_bias=False)
        self.dense3 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.transpose1 = P.Transpose()
        self.add = P.Add()
        if shard:
            self.dsd_matmul.shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            # The v projection needs the same strategy as the q and k projections.
            self.dense3.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))
            self.transpose1.shard(((dp, mp, 1, 1, 1, 1),))

    def construct(self, x):
        # x: (batch_size * seq_len, embedding_size)
        q = self.dense1(x)
        # q: (batch_size * seq_len, embedding_size // 2), regrouped below into
        #    (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
        k = self.dense2(x)
        # k: (batch_size * seq_len, embedding_size * 2), regrouped below into
        #    (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
        v = self.dense3(x)
        # v: (batch_size * seq_len, embedding_size)
        q = self.reshape(q, (self.batch_size, self.num_heads, self.block_num, self.head_size // 16,
                             self.block_size // 16, 16, 16))
        k = self.reshape(k, (self.batch_size, self.num_heads, self.block_num, self.global_size // 16,
                             self.head_size // 16, 16, 16))
        # Reshape to (batch_size * seq_len // 16, 16, embedding_size // 16, 16), then
        # permute to (batch_size * seq_len // 16, embedding_size // 16, 16, 16).
        v = self.transpose(self.reshape(v, (-1, 16, self.embedding_size // 16, 16)), (0, 2, 3, 1))
        dsd = self.dsd_matmul(q, k, v)
        # dsd: (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
        dsd = self.transpose1(dsd, (0, 1, 3, 4, 2, 5))
        # dsd: (batch_size, head, seq_len // 16, 16, v_embedding // 16, 16)
        dsd = self.reshape(dsd, (-1, self.seq_len, self.v_embedding * self.num_heads))
        # dsd: (batch_size, seq_len, embedding_size); sum over the last axis.
        result = self.reduce_sum(dsd, 2)
        return result


class GradWrap(nn.Cell):
    """Return the gradients of `network` with respect to all of its inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    """Feed the output of `network` into VirtualLoss."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    """Compile the training graph of Net under a parallel mode.

    The caller is expected to have set device_num/global_rank already.

    Args:
        batch_size, num_heads, dp, mp, shard: forwarded to Net.
        auto: use "auto_parallel" when True, otherwise "semi_auto_parallel".
    """
    parallel_mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=parallel_mode)
    x = Tensor(np.ones((batch_size * SEQ_LEN, num_heads * V_EMBEDDING)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def teardown_function():
    """Restore process-wide parallel settings after each test (pytest hook).

    Tests mutate global state (device_num, parallel_mode and, in some cases,
    fully_use_devices). Without this reset, later tests inherit it and their
    outcome depends on execution order.
    """
    context.reset_auto_parallel_context()
    # True is the documented default of fully_use_devices.
    set_algo_parameters(fully_use_devices=True)


def test_dsd_matmul_model_parallel_mix():
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=2, mp=8)


def test_dsd_matmul_model_parallel_dp():
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=16, mp=1)


def test_dsd_matmul_model_parallel_mp():
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=1, mp=16)


def test_dsd_matmul_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=2, mp=8, auto=True)


def test_dsd_matmul_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=16, mp=1, auto=True)


def test_dsd_matmul_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=1, mp=16, auto=True)


def test_dsd_matmul_model_parallel_auto():
    # No explicit strategies: the strategy search picks every layout itself.
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=DEVICE_NUM, global_rank=0)
    compile_graph(BATCH_SIZE, NUM_HEADS, dp=1, mp=16, auto=True, shard=False)