# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import MatmulDDS
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)


# MatmulDDS input/output layouts (see the worked example at the end of this file):
# q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
# global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
# local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
# global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
# x: (bs * seq_len, num_heads * size_per_head)
class Net(nn.Cell):
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.size_per_head = 128
        self.seq_len = 1024
        self.block_size = 64
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.embedding_size = num_heads * self.size_per_head
        self.cus_matmul = MatmulDDS(batch_size, num_heads)
        self.reduce_sum = P.ReduceSum()
        self.global_mask = Tensor(np.ones((batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)))
        self.local_mask = Tensor(np.ones((self.block_num * self.block_size // 16,
                                          batch_size * self.block_size // 16, 16, 16)))
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.add = P.Add()
        if shard:
            # mp splits the heads dimension, dp splits the batch dimension
            self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))

    def construct(self, x):
        q = self.dense1(x)
        k = self.dense2(x)
        q = self.transpose(self.reshape(q, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        k = self.transpose(self.reshape(k, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        local_prob, global_prob = self.cus_matmul(q, k, self.local_mask, self.global_mask)
        local_prob = self.reshape(local_prob, (self.batch_size, self.num_heads, -1))
        global_prob = self.reshape(global_prob, (self.batch_size, self.num_heads, -1))
        local_prob_reduce = self.reduce_sum(local_prob, 2)
        global_prob_reduce = self.reduce_sum(global_prob, 2)
        result = self.add(local_prob_reduce, global_prob_reduce)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    # Compile-only check: verify the sharding strategies pass graph compilation.
    _cell_graph_executor.compile(net, x)


def test_cus_matmul_dds_model_parallel_mix():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_dp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat_cal_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 2
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat1_cal_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)
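

# Worked shape example for the MatmulDDS layouts documented above (illustrative
# only: these are the layout comments instantiated with this file's test
# constants, bs=128, num_heads=32, size_per_head=128, seq_len=1024,
# block_size=64, global_size=256, hence block_num=16):
#   q, k:        (32 * 128 // 16, 128 * 1024 // 16, 16, 16) = (256, 8192, 16, 16)
#   local_mask:  (16 * 64 // 16, 128 * 64 // 16, 16, 16)    = (64, 512, 16, 16)
#   global_mask: (128 * 256 // 16, 1024 // 16, 16, 16)      = (2048, 64, 16, 16)
#
# A minimal standalone invocation sketch (assumption: running these compile-only
# checks outside the pytest harness is acceptable; they build the parallel graph
# without launching real devices).
if __name__ == "__main__":
    test_cus_matmul_dds_model_parallel_mix()
    test_cus_matmul_dds_model_parallel_dp()
    test_cus_matmul_dds_model_parallel_mp()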