# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
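"""Parallel compilation tests for the MatmulDDS operator under different sharding setups."""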
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import MatmulDDS
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)

# q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
# global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
# local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
# global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
# x: (bs*seq_len, num_heads*size_per_head)
class Net(nn.Cell):
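    """Projects x into q/k, feeds them with the local/global masks to MatmulDDS, and sums the outputs."""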
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.size_per_head = 128
        self.seq_len = 1024
        self.block_size = 64
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.embedding_size = num_heads * self.size_per_head
        self.cus_matmul = MatmulDDS(batch_size, num_heads)
        self.reduce_sum = P.ReduceSum()
        self.global_mask = Tensor(np.ones((batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)))
        self.local_mask = Tensor(np.ones((self.block_num * self.block_size // 16,
                                          batch_size * self.block_size // 16, 16, 16)))
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.add = P.Add()
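        # Optional user-specified sharding: the MatmulDDS inputs (q, k, local_mask, global_mask)
        # get the strategies below, and the Dense matmuls/transpose are split over (dp, mp).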
        if shard:
            self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))

    def construct(self, x):
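        """Forward pass: q/k projection, 16x16 block reshape, MatmulDDS, then reduce and add."""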
        q = self.dense1(x)
        k = self.dense2(x)
        q = self.transpose(self.reshape(q, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        k = self.transpose(self.reshape(k, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        local_prob, global_prob = self.cus_matmul(q, k, self.local_mask, self.global_mask)
        local_prob = self.reshape(local_prob, (self.batch_size, self.num_heads, -1))
        global_prob = self.reshape(global_prob, (self.batch_size, self.num_heads, -1))
        local_prob_reduce = self.reduce_sum(local_prob, 2)
        global_prob_reduce = self.reduce_sum(global_prob, 2)
        result = self.add(local_prob_reduce, global_prob_reduce)
        return result


class GradWrap(nn.Cell):
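    """Wraps a network and returns the gradients with respect to all of its inputs."""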
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
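    """Attaches a VirtualLoss head to the wrapped network for compile-only tests."""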
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
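    """Builds GradWrap(NetWithLoss(Net(...))) and compiles it under semi_auto_parallel,
    or under auto_parallel when auto is True."""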
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_cus_matmul_dds_model_parallel_mix():
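    """Mixed data/model parallel for MatmulDDS: dp=2, mp=8 on 16 devices, semi_auto_parallel."""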
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_dp():
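    """Pure data parallel for MatmulDDS: dp=16, mp=1 on 16 devices, semi_auto_parallel."""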
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mp():
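    """Pure model parallel for MatmulDDS: dp=1, mp=16 on 16 devices, semi_auto_parallel."""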
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mix_auto():
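    """Mixed dp=2, mp=8 strategies under auto_parallel with fully_use_devices=False."""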
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_dp_auto():
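    """dp=16, mp=1 strategies under auto_parallel on 16 devices."""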
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_mp_auto():
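    """dp=1, mp=16 strategies under auto_parallel on 16 devices."""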
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_auto():
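    """auto_parallel with shard=False: no manual strategies, the strategy search picks the sharding."""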
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat_cal_auto():
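    """auto_parallel with shard=False on 16 devices, covering the repeated-calculation case."""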
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 2
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat1_cal_auto():
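    """auto_parallel with shard=False on 16 devices, a second repeated-calculation variant."""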
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)