# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.ops as P
from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
    VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig, FixedSparseAttention
from mindspore.nn import Dense as Linear
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.train import Model
from mindspore.parallel import set_algo_parameters
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


class Dataset(MindData):
    def __init__(self, *inputs, length=3):
        super(Dataset, self).__init__(size=length)
        self.inputs = inputs
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.inputs

    def reset(self):
        self.index = 0


class TransformerNet(nn.Cell):
    def __init__(self, en_layer, de_layer, parallel_config):
        super(TransformerNet, self).__init__()
        self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
                                        parallel_config=config.embedding_dp_mp_config)
        self.network = Transformer(encoder_layers=en_layer,
                                   decoder_layers=de_layer,
                                   batch_size=2,
                                   src_seq_length=20,
                                   tgt_seq_length=10,
                                   hidden_size=64,
                                   num_heads=8,
                                   ffn_hidden_size=64,
                                   parallel_config=parallel_config)
        self.head = Linear(in_channels=64, out_channels=200)
        self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)

    def construct(self, x1, x2, x3, x4, x5, y, mask):
        predict, _, _ = self.network(x1, x2, x3, x4, x5)
        predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
        return self.loss(predict, y, mask)


config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
                                              micro_batch_num=4, vocab_emb_dp=False)


class NetWithLossFiveInputs(nn.Cell):
    def __init__(self, network):
        super(NetWithLossFiveInputs, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x1, x2, x3, x4, x5):
        predict, _, _ = self.network(x1, x2, x3, x4, x5)
        return self.loss(predict)


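# Shared helper for the transformer-with-head tests below: it builds a
# TransformerNet under the given parallel configuration, flattens the labels to
# (2 * seq,) where seq is the decoder length (10) when d_layer > 0 and the
# encoder length (20) otherwise, and runs a single training step. The parallel
# context is only set when dp * mp * pp != 1, i.e. not in the stand-alone case.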
def run_total_transformer_model_head(e_layer,
                                     d_layer,
                                     arg_parallel_config,
                                     mode=ParallelMode.SEMI_AUTO_PARALLEL):
    dp = arg_parallel_config.data_parallel
    mp = arg_parallel_config.model_parallel
    pp = arg_parallel_config.pipeline_stage
    if dp * mp * pp != 1:
        set_auto_parallel_context(device_num=8,
                                  full_batch=True,
                                  global_rank=0, parallel_mode=mode)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    seq = 20
    if d_layer > 0:
        seq = 10
    label = Tensor(np.ones((2 * seq,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
    net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask, label, input_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


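# Same network as above, but the encoder/decoder inputs are passed pre-flattened
# as 2-D tensors: (40, 64) = (2 * 20, 64) for the encoder and (20, 64) = (2 * 10, 64)
# for the decoder, while the attention masks stay 3-D.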
def test_transformer_model_2d_inputs():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_int64_inputs():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = _VirtualDatasetCell(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    with pytest.raises(TypeError):
        model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_head_parallel_only_encoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)


def test_transformer_model_head_parallel():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_parallel_decoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    with pytest.raises(ValueError):
        run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_stand_alone():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
    run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)


def test_transformer_model_auto_parallel_no_support():
    local_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1)
    with pytest.raises(RuntimeError):
        run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config,
                                         mode=ParallelMode.AUTO_PARALLEL)


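# Pipeline-parallel case: 32 devices are configured with pipeline_stage=4 (see
# pipeline_config above), PipelineCell splits the batch into micro_batch_num=4
# micro-batches, and infer_param_pipeline_stage() hands the optimizer only the
# parameters belonging to the local pipeline stage.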
def test_pipeline_single_transformer():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      encoder_layers=2,
                      decoder_layers=2,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=pipeline_config)

    encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = PipelineCell(net, pipeline_config.micro_batch_num)
    net = _VirtualDatasetCell(net)
    params = net.infer_param_pipeline_stage()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
    net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
                                                    scale_sense=update_cell)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_wrong_head():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
    with pytest.raises(ValueError):
        net = Transformer(batch_size=4,
                          src_seq_length=20,
                          tgt_seq_length=10,
                          encoder_layers=2,
                          decoder_layers=2,
                          hidden_size=64,
                          num_heads=7,
                          ffn_hidden_size=64,
                          parallel_config=error_test_config)

    with pytest.raises(ValueError):
        net = Transformer(batch_size=4,
                          src_seq_length=20,
                          tgt_seq_length=10,
                          encoder_layers=2,
                          decoder_layers=2,
                          hidden_size=63,
                          num_heads=7,
                          ffn_hidden_size=64,
                          parallel_config=error_test_config)
        del net


def test_transformer_wrong_dp_no_error():
    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
    check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
    net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
                      decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
                      parallel_config=check_config)
    del net


def test_transformer_wrong_semi_auto_dp_error():
    set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
    check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
    with pytest.raises(ValueError):
        net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
                          decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
                          parallel_config=check_config)
        del net


def test_encoder():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x2):
            predict, _ = self.network(x1, x2)
            return self.loss(predict)

    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerEncoder(num_layers=2,
                             batch_size=2,
                             seq_length=16,
                             hidden_size=8,
                             ffn_hidden_size=64,
                             num_heads=8,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)

    net = NetWithLoss(net)

    net = _VirtualDatasetCell(net)

    dataset = Dataset(encoder_input_value, encoder_input_mask)

    model = Model(net)

    model.train(1, dataset, dataset_sink_mode=False)


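# TransformerDecoder is exercised stand-alone; the mock dataset feeds inputs in
# the (decoder_input, decoder_mask, encoder_output, memory_mask) order used by
# NetWithLoss.construct, with batch_size=8, tgt_seq_length=10 and src_seq_length=20.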
def test_decoder():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x2, x3, x4):
            predict, _, _ = self.network(x1, x2, x3, x4)
            return self.loss(predict)

    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerDecoder(num_layers=1,
                             batch_size=8,
                             hidden_size=16,
                             ffn_hidden_size=8,
                             num_heads=8,
                             src_seq_length=20,
                             tgt_seq_length=10,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
    decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)

    net = NetWithLoss(net)

    net = _VirtualDatasetCell(net)

    dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_vocabembedding_dp_true():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1):
            predict, _ = self.network(x1)
            return self.loss(predict)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    net = _VirtualDatasetCell(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_vocabembedding_dp_false():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1):
            predict, _ = self.network(x1)
            return self.loss(predict)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    net = _VirtualDatasetCell(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_mp():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(model_parallel=8)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


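# The remaining FixedSparseAttention tests reuse the same shapes, where the last
# q/k/v dimension of 512 presumably corresponds to num_heads * size_per_head
# (8 * 64), and only vary how the 8 devices are split between data_parallel and
# model_parallel.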
def test_sparse_attention_parallel_mix():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_mix1():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_sparse_attention_parallel_dp():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
    set_algo_parameters(fully_use_devices=False)
    sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
    net = FixedSparseAttention(batch_size=16,
                               seq_length=1024,
                               size_per_head=64,
                               num_heads=8,
                               block_size=64,
                               parallel_config=sparse_attention_config)
    net = _VirtualDatasetCell(net)
    q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
    mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
    dataset = Dataset(q, k, v, mask)
    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_parallel_cross_entroy_loss_semi_auto_parallel():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network, config_setting):
            super(NetWithLoss, self).__init__()
            self.loss = CrossEntropyLoss(config_setting)
            self.network = network

        def construct(self, x1, x2, x3):
            predict, _ = self.network(x1)
            predict = P.Reshape()(predict, (-1, 16))
            return self.loss(predict, x2, x3)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net, config.dp_mp_config)
    net = _VirtualDatasetCell(net)
    embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
    labels = Tensor(np.ones((2 * 64,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
    dataset = Dataset(embed_ids, labels, input_mask)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


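# The tests below cover argument and parallel-config validation only: invalid
# argument types raise TypeError, invalid values raise ValueError, and no
# training step is executed.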
def test_transformer_args():
    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, decoder_layers="aa")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length="a")

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, softmax_compute_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, layernorm_compute_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, param_init_type=mstype.int64)

    with pytest.raises(TypeError):
        Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                    tgt_seq_length=20, hidden_dropout_rate=mstype.int64)

    Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
                tgt_seq_length=20, softmax_compute_type=mstype.float16)


def test_transformer_parallel_config():
    parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(TypeError):
        parallel_test_config.data_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.pipeline_stage = False

    with pytest.raises(ValueError):
        parallel_test_config.pipeline_stage = 0

    with pytest.raises(TypeError):
        parallel_test_config.micro_batch_num = False

    with pytest.raises(ValueError):
        parallel_test_config.micro_batch_num = 0

    with pytest.raises(TypeError):
        parallel_test_config.gradient_aggregation_group = False

    with pytest.raises(ValueError):
        parallel_test_config.gradient_aggregation_group = 0

    with pytest.raises(TypeError):
        parallel_test_config.recompute = 1

    parallel_test_config.recompute = False

    assert not parallel_test_config.recompute


def test_parallel_config():
    parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    assert parallel_test_config.model_parallel == 3


def test_embedding_parallel_config():
    parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.vocab_emb_dp = 0

    assert not parallel_test_config.vocab_emb_dp