1# Copyright 2019 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import numpy as np 16 17import mindspore.nn as nn 18import mindspore as ms 19from mindspore import Tensor, context, Parameter 20from mindspore.common.api import _cell_graph_executor 21from mindspore.ops import operations as P 22from mindspore.common.initializer import initializer 23from mindspore.context import _Context 24from ....train_step_wrap import train_step_with_loss_warp 25 26class MatMulCell(nn.Cell): 27 def __init__(self): 28 super(MatMulCell, self).__init__() 29 self.reshape = P.Reshape() 30 self.matmul0 = P.MatMul() 31 self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight") 32 self.relu = P.ReLU().shard(((1, 8),)) 33 def construct(self, x): 34 x = self.matmul0(x, self.weight) 35 x = self.reshape(x, (32, 128)) 36 x = self.relu(x) 37 return x 38 39class DenseMutMulNet(nn.Cell): 40 def __init__(self): 41 super(DenseMutMulNet, self).__init__() 42 self.fc1 = nn.Dense(128, 768, activation='relu') 43 self.fc2 = nn.Dense(128, 768, activation='relu') 44 self.fc3 = nn.Dense(128, 768, activation='relu') 45 self.fc4 = nn.Dense(768, 768, activation='relu') 46 self.fc1.matmul.shard(((1, 1), (1, 8))) 47 self.fc2.matmul.shard(((1, 1), (1, 8))) 48 self.fc3.matmul.shard(((1, 1), (1, 8))) 49 self.relu4 = nn.ReLU() 50 self.relu5 = nn.ReLU() 51 self.transpose = P.Transpose() 52 self.matmul1 = P.MatMul() 53 self.matmul2 = P.MatMul() 54 self.matmul_cell = MatMulCell() 55 self.fc1.recompute(mp_comm_recompute=False) 56 self.fc2.recompute(mp_comm_recompute=False) 57 self.fc3.recompute(mp_comm_recompute=False) 58 self.matmul_cell.recompute(mp_comm_recompute=False) 59 60 def construct(self, x): 61 x = self.matmul_cell(x) 62 q = self.fc1(x) 63 k = self.fc2(x) 64 v = self.fc3(x) 65 k = self.transpose(k, (1, 0)) 66 c = self.relu4(self.matmul1(q, k)) 67 s = self.relu5(self.matmul2(c, v)) 68 s = self.fc4(s) 69 return s 70 71 72def test_dmnet_train_step(): 73 context.reset_auto_parallel_context() 74 _Context().set_backend_policy("vm") 75 context.set_context(mode=context.GRAPH_MODE) 76 context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8) 77 input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01) 78 label = Tensor(np.zeros([32, 768]).astype(np.float32)) 79 net = train_step_with_loss_warp(DenseMutMulNet()) 80 net.set_auto_parallel() 81 net.set_train() 82 _cell_graph_executor.compile(net, input_, label) 83 _Context().set_backend_policy("ge") 84