# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.ops as P
from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
    VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig, FixedSparseAttention
from mindspore.nn import Dense as Linear
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.train import Model
from mindspore.parallel import set_algo_parameters
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


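# Mock dataset: yields the same tuple of pre-built Tensors for `length` steps,
# so Model.train() can iterate without any real data loading.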
class Dataset(MindData):
    def __init__(self, *inputs, length=3):
        super(Dataset, self).__init__(size=length)
        self.inputs = inputs
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.inputs

    def reset(self):
        self.index = 0


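# Transformer network with an embedding, a linear head and a cross-entropy loss.
# The Transformer layers use the `parallel_config` argument, while the embedding
# and the loss take their parallel configs from the module-level `config` below.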
class TransformerNet(nn.Cell):
    def __init__(self, en_layer, de_layer, parallel_config):
        super(TransformerNet, self).__init__()
        self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
                                        parallel_config=config.embedding_dp_mp_config)
        self.network = Transformer(encoder_layers=en_layer,
                                   decoder_layers=de_layer,
                                   batch_size=2,
                                   src_seq_length=20,
                                   tgt_seq_length=10,
                                   hidden_size=64,
                                   num_heads=8,
                                   ffn_hidden_size=64,
                                   parallel_config=parallel_config)
        self.head = Linear(in_channels=64, out_channels=200)
        self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)

    def construct(self, x1, x2, x3, x4, x5, y, mask):
        predict, _, _ = self.network(x1, x2, x3, x4, x5)
        predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
        return self.loss(predict, y, mask)

config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
                                              micro_batch_num=4, vocab_emb_dp=False)


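# Wraps a five-input Transformer and feeds only its first output into a
# VirtualLoss so the graph produces a single training target.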
class NetWithLossFiveInputs(nn.Cell):
    def __init__(self, network):
        super(NetWithLossFiveInputs, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x1, x2, x3, x4, x5):
        predict, _, _ = self.network(x1, x2, x3, x4, x5)
        return self.loss(predict)


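# Helper shared by the *_head tests: builds TransformerNet with the given
# encoder/decoder depths, sets the parallel context when dp * mp * pp != 1,
# and runs one epoch of TrainOneStepCell over the mock dataset.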
def run_total_transformer_model_head(e_layer,
                                     d_layer,
                                     arg_parallel_config,
                                     mode=ParallelMode.SEMI_AUTO_PARALLEL):
    dp = arg_parallel_config.data_parallel
    mp = arg_parallel_config.model_parallel
    pp = arg_parallel_config.pipeline_stage
    if dp * mp * pp != 1:
        set_auto_parallel_context(device_num=8,
                                  full_batch=True,
                                  global_rank=0, parallel_mode=mode)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    seq = 20
    if d_layer > 0:
        seq = 10
    label = Tensor(np.ones((2 * seq,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
    net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask, label, input_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


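# Basic semi-auto parallel cases: compile and train a small Transformer on a
# mocked 8-device setup with 3-D inputs, 2-D inputs, and (intentionally
# invalid) int64 encoder inputs that must raise a TypeError.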
def test_transformer_model():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_2d_inputs():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_int64_inputs():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    with pytest.raises(TypeError):
        model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_head_parallel_only_encoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)


def test_transformer_model_head_parallel():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_parallel_decoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    with pytest.raises(ValueError):
        run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_stand_alone():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
    run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)


def test_transformer_model_auto_parallel_no_support():
    local_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1)
    with pytest.raises(RuntimeError):
        run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config,
                                         mode=ParallelMode.AUTO_PARALLEL)


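# Pipeline parallelism: 32 devices, 4 pipeline stages and 4 micro-batches;
# PipelineCell splits each batch into micro-batches and
# _TrainPipelineWithLossScaleCell adds dynamic loss scaling on top.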
def test_pipeline_single_transformer():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      encoder_layers=2,
                      decoder_layers=2,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=pipeline_config)

    encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = PipelineCell(net, pipeline_config.micro_batch_num)
    net = _VirtualDatasetCell(net)
    params = net.infer_param_pipeline_stage()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
    net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
                                                    scale_sense=update_cell)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_wrong_head():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
    with pytest.raises(ValueError):
        net = Transformer(batch_size=4,
                          src_seq_length=20,
                          tgt_seq_length=10,
                          encoder_layers=2,
                          decoder_layers=2,
                          hidden_size=64,
                          num_heads=7,
                          ffn_hidden_size=64,
                          parallel_config=error_test_config)

    with pytest.raises(ValueError):
        net = Transformer(batch_size=4,
                          src_seq_length=20,
                          tgt_seq_length=10,
                          encoder_layers=2,
                          decoder_layers=2,
                          hidden_size=63,
                          num_heads=7,
                          ffn_hidden_size=64,
                          parallel_config=error_test_config)
        del net


def test_transformer_wrong_dp_no_error():
    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
    check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
    net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
                      decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
                      parallel_config=check_config)
    del net


def test_transformer_wrong_semi_auto_dp_error():
    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
    check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
    with pytest.raises(ValueError):
        net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
                          decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
                          parallel_config=check_config)
        del net


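# Standalone TransformerEncoder / TransformerDecoder runs under semi-auto
# parallel, each wrapped with a local VirtualLoss cell.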
def test_encoder():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x2):
            predict, _ = self.network(x1, x2)
            return self.loss(predict)

    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerEncoder(num_layers=2,
                             batch_size=2,
                             seq_length=16,
                             hidden_size=8,
                             ffn_hidden_size=64,
                             num_heads=8,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)

    net = NetWithLoss(net)

    net = _VirtualDatasetCell(net)

    dataset = Dataset(encoder_input_value, encoder_input_mask)

    model = Model(net)

    model.train(1, dataset, dataset_sink_mode=False)


def test_decoder():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x2, x3, x4):
            predict, _, _ = self.network(x1, x2, x3, x4)
            return self.loss(predict)

    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerDecoder(num_layers=1,
                             batch_size=8,
                             hidden_size=16,
                             ffn_hidden_size=8,
                             num_heads=8,
                             src_seq_length=20,
                             tgt_seq_length=10,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
    decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)

    net = NetWithLoss(net)

    net = _VirtualDatasetCell(net)

    dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_vocabembedding_dp_true():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1):
            predict, _ = self.network(x1)
            return self.loss(predict)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    net = _VirtualDatasetCell(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_vocabembedding_dp_false():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1):
            predict, _ = self.network(x1)
            return self.loss(predict)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    net = _VirtualDatasetCell(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


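# FixedSparseAttention under AUTO_PARALLEL with different data/model parallel
# splits; q, k and v have shape (batch, seq_length, num_heads * size_per_head).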
def test_sparse_attention_parallel_mp():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(model_parallel=8)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_mix():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_mix1():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_dp():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    net = _VirtualDatasetCell(net)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_parallel_cross_entropy_loss_semi_auto_parallel():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network, config_setting):
            super(NetWithLoss, self).__init__()
            self.loss = CrossEntropyLoss(config_setting)
            self.network = network

        def construct(self, x1, x2, x3):
            predict, _ = self.network(x1)
            predict = P.Reshape()(predict, (-1, 16))
            return self.loss(predict, x2, x3)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net, config.dp_mp_config)
    net = _VirtualDatasetCell(net)
    embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
    labels = Tensor(np.ones((2 * 64,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
    dataset = Dataset(embed_ids, labels, input_mask)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


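# Argument and parallel-config validation: invalid types or values must raise
# TypeError or ValueError, and valid settings must be accepted.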
def test_transformer_args():

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, decoder_layers="aa")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length="a")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, softmax_compute_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, layernorm_compute_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, param_init_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, hidden_dropout_rate=mstype.int64)

    Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                tgt_seq_length=20, softmax_compute_type=mstype.float16)


def test_transformer_parallel_config():
    parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(TypeError):
        parallel_test_config.data_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.pipeline_stage = False

    with pytest.raises(ValueError):
        parallel_test_config.pipeline_stage = 0

    with pytest.raises(TypeError):
        parallel_test_config.micro_batch_num = False

    with pytest.raises(ValueError):
        parallel_test_config.micro_batch_num = 0

    with pytest.raises(TypeError):
        parallel_test_config.gradient_aggregation_group = False

    with pytest.raises(ValueError):
        parallel_test_config.gradient_aggregation_group = 0

    with pytest.raises(TypeError):
        parallel_test_config.recompute = 1

    parallel_test_config.recompute = False

    assert not parallel_test_config.recompute


def test_parallel_config():
    parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    assert parallel_test_config.model_parallel == 3


def test_embedding_parallel_config():
    parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.vocab_emb_dp = 0

    assert not parallel_test_config.vocab_emb_dp