# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import DSDMatmul
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)


# input_w1, the shape is (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
# input_w1 cum_shape = batch_size * seq_len * embedding_size * (block_size // size_per_head)
#                    = batch_size * seq_len * (embedding_size // 2)
# input_w2, the shape is (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
# input_w2 cum_shape = batch_size * seq_len * embedding_size * (global_size // size_per_head)
#                    = batch_size * seq_len * embedding_size * 2
# input_v, the shape is (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16)
# block_num = seq_len // block_size, block_size = 64, head * v_embedding = embedding_size, always.
# output shape is (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
class Net(nn.Cell):
    """Small network that feeds three bias-free Dense projections into DSDMatmul.

    Args:
        batch_size: Batch size used when reshaping the projections.
        num_heads: Number of heads; the embedding size is ``num_heads * 128``.
        dp: Split factor applied to the first (batch) dimension in the strategies.
        mp: Split factor applied to the second (head / output-channel) dimension.
        shard: When True, explicit shard strategies are attached to the operators;
            when False, no strategy is set on any operator.
    """

    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_len = 1024
        self.block_size = 64
        self.head_size = self.block_size
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.v_embedding = 128
        self.embedding_size = num_heads * self.v_embedding
        self.dsd_matmul = DSDMatmul()
        self.reduce_sum = P.ReduceSum()
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size // 2, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size * 2, has_bias=False)
        self.dense3 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.transpose1 = P.Transpose()
        self.add = P.Add()
        if shard:
            self.dsd_matmul.shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            # Previously dense2 was sharded twice and dense3 was left without a
            # strategy; the V projection gets the same strategy as Q and K.
            self.dense3.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))
            self.transpose1.shard(((dp, mp, 1, 1, 1, 1),))

    def construct(self, x):
        """Project ``x`` into q/k/v, run DSDMatmul and reduce over the last axis."""
        # x (batch_size * seq_len, embedding_size)
        q = self.dense1(x)
        # q (batch_size * seq_len, (embedding_size // 2))
        # (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
        k = self.dense2(x)
        # k (batch_size * seq_len, (embedding_size * 2))
        # (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
        v = self.dense3(x)
        # v (batch_size * seq_len, embedding_size)
        q = self.reshape(q, (self.batch_size, self.num_heads, self.block_num,
                             self.head_size // 16, self.block_size // 16, 16, 16))
        k = self.reshape(k, (self.batch_size, self.num_heads, self.block_num,
                             self.global_size // 16, self.head_size // 16, 16, 16))
        v = self.transpose(self.reshape(v, (-1, 16, self.embedding_size // 16, 16)), (0, 2, 3, 1))
        dsd = self.dsd_matmul(q, k, v)
        # dsd (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
        dsd = self.transpose1(dsd, (0, 1, 3, 4, 2, 5))
        # dsd (batch_size, head, seq_len // 16, 16, v_embedding_size // 16, 16)
        dsd = self.reshape(dsd, (-1, self.seq_len, self.v_embedding * self.num_heads))
        result = self.reduce_sum(dsd, 2)
        return result


class GradWrap(nn.Cell):
    """Wrap a network so that calling it returns the gradients w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        """Return ``grad_all(self.network)(x)``."""
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    """Append a ``VirtualLoss`` to the wrapped network's output."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        """Run the network on ``x`` and pass the prediction through the loss."""
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    """Build the grad-wrapped ``Net`` and compile it under a parallel mode.

    Args:
        batch_size: Batch size forwarded to ``Net``.
        num_heads: Number of heads forwarded to ``Net``.
        dp: First-dimension split factor forwarded to ``Net``.
        mp: Second-dimension split factor forwarded to ``Net``.
        auto: Use ``"auto_parallel"`` when True, otherwise ``"semi_auto_parallel"``.
        shard: Forwarded to ``Net``; whether explicit strategies are attached.
    """
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # 1024 and 128 mirror Net.seq_len and Net.v_embedding.
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_dsd_matmul_model_parallel_mix():
    """Compile with dp=2, mp=8 on 16 devices in semi_auto_parallel mode."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_dp():
    """Compile with dp=16, mp=1 on 16 devices in semi_auto_parallel mode."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mp():
    """Compile with dp=1, mp=16 on 16 devices in semi_auto_parallel mode."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mix_auto():
    """Compile with dp=2, mp=8 in auto_parallel mode, fully_use_devices disabled."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_dp_auto():
    """Compile with dp=16, mp=1 on 16 devices in auto_parallel mode."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_mp_auto():
    """Compile with dp=1, mp=16 on 16 devices in auto_parallel mode."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_auto():
    """Compile in auto_parallel mode with no explicit shard strategies on Net."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)