# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import DSDMatmul
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)


#  input_w1, the shape is (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
#  input_w1 cum_shape = batch_size * seq_len * embedding_size * (block_size // size_per_head)
#  = batch_size * seq_len * (embedding_size // 2)
#  input_w2, the shape is (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
#  input_w2 cum_shape = batch_size * seq_len * embedding_size * (global_size // size_per_head)
#  = batch_size * seq_len * embedding_size * 2
#  input_v, the shape is (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16)
#  block_num = seq_len // block_size, block_size = 64, and head * v_embedding = embedding_size, always.
#  output shape is (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
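#
#  For example, with the configuration the tests below use (batch_size=128, num_heads=32,
#  so embedding_size = 32 * 128 = 4096), the formulas above give:
#  input_w1: (128, 32, 16, 4, 4, 16, 16)
#  input_w2: (128, 32, 16, 16, 4, 16, 16)
#  input_v:  (8192, 256, 16, 16)
#  output:   (128, 32, 8, 64, 16, 16)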


class Net(nn.Cell):
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_len = 1024
        self.block_size = 64
        self.head_size = self.block_size
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.v_embedding = 128
        self.embedding_size = num_heads * self.v_embedding
        self.dsd_matmul = DSDMatmul()
        self.reduce_sum = P.ReduceSum()
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size // 2, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size * 2, has_bias=False)
        self.dense3 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.transpose1 = P.Transpose()
        self.add = P.Add()
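        # Each tuple passed to shard() below gives the number of slices per dimension of the
        # corresponding input: dp splits the batch(-like) leading dimension across data-parallel
        # groups, while mp splits the head / output-channel dimension across model-parallel groups.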
        if shard:
            self.dsd_matmul.shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.dense3.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))
            self.transpose1.shard(((dp, mp, 1, 1, 1, 1),))

    def construct(self, x):
        # x: (batch_size * seq_len, embedding_size)
        q = self.dense1(x)
        # q: (batch_size * seq_len, embedding_size // 2),
        # later reshaped to (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
        k = self.dense2(x)
        # k: (batch_size * seq_len, embedding_size * 2),
        # later reshaped to (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
        v = self.dense3(x)
        # v: (batch_size * seq_len, embedding_size)
        q = self.reshape(q, (self.batch_size, self.num_heads, self.block_num, self.head_size // 16,
                             self.block_size // 16, 16, 16))
        k = self.reshape(k, (self.batch_size, self.num_heads, self.block_num, self.global_size // 16,
                             self.head_size // 16, 16, 16))
        v = self.transpose(self.reshape(v, (-1, 16, self.embedding_size // 16, 16)), (0, 2, 3, 1))
        dsd = self.dsd_matmul(q, k, v)
        # dsd: (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
        dsd = self.transpose1(dsd, (0, 1, 3, 4, 2, 5))
        # dsd: (batch_size, head, seq_len // 16, 16, v_embedding // 16, 16)
        dsd = self.reshape(dsd, (-1, self.seq_len, self.v_embedding * self.num_heads))
        result = self.reduce_sum(dsd, 2)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


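# dp is the data-parallel slice count and mp the model-parallel slice count; the tests below
# pick them so that dp * mp equals the configured device_num of 16.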
def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_dsd_matmul_model_parallel_mix():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)

def test_dsd_matmul_model_parallel_dp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)

def test_dsd_matmul_model_parallel_mp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)

def test_dsd_matmul_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)

def test_dsd_matmul_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)

def test_dsd_matmul_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)

def test_dsd_matmul_model_parallel_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)

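# These cases are intended to be collected by pytest, e.g. (the path is only illustrative):
#   pytest -k dsd_matmul tests/ut/python/parallel/test_dsd_matmul.py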