# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.context import set_auto_parallel_context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps *network* with a VirtualLoss so the graph has a scalar output."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Wraps *network* to return the gradients w.r.t. all of its inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


def compile_net(net, x):
    """Compile *net* in auto-parallel training mode with input *x*."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


class Net(nn.Cell):
    """A minimal self-attention block (query/key/value MatMuls) whose five
    MatMul ops each take an explicit shard strategy (or None for default).

    Args:
        strategy1-strategy3: shard strategies for the query/key/value MatMuls.
        strategy4: shard strategy for the score MatMul (q x k^T).
        strategy5: shard strategy for the context MatMul (v^T x s).
    """

    def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5):
        super().__init__()
        self.query_w = Parameter(initializer(
            "normal", [8, 16], ms.float32), name='query')
        self.query = P.MatMul().shard(strategy1)

        self.key_w = Parameter(initializer(
            "normal", [8, 16], ms.float32), name='key')
        self.key = P.MatMul().shard(strategy2)

        self.value_w = Parameter(initializer(
            "normal", [8, 16], ms.float32), name='value')
        self.value = P.MatMul().shard(strategy3)

        self.score = P.MatMul().shard(strategy4)
        self.context = P.MatMul().shard(strategy5)
        self.transpose1 = P.Transpose()
        self.transpose2 = P.Transpose()
        self.relu = P.ReLU()

    def construct(self, x):
        q = self.query(x, self.query_w)
        k = self.key(x, self.key_w)
        v = self.value(x, self.value_w)

        # score = q x k^T; context = v^T x score
        k = self.transpose1(k, (1, 0))
        s = self.score(q, k)

        v = self.transpose2(v, (1, 0))
        c = self.context(v, s)
        out = self.relu(c)

        return out


def _build_and_compile(parallel_mode, strategies=(None, None, None, None, None)):
    """Shared test driver: set the parallel context, build the wrapped
    attention net with the given five shard *strategies*, and compile it."""
    set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode=parallel_mode)

    net = GradWrap(NetWithLoss(Net(*strategies)))
    x = Tensor(np.ones([32, 8]), dtype=ms.float32)

    compile_net(net, x)


def test_self_attention_standalone():
    """Compiles with no sharding in stand-alone mode."""
    _build_and_compile("stand_alone")


def test_self_attention_semi():
    """Compiles with mixed model/data-parallel strategies in semi-auto mode."""
    strategy1 = ((2, 2), (2, 2))
    strategy2 = ((2, 2), (2, 2))
    strategy3 = ((2, 2), (2, 2))
    strategy4 = ((2, 4), (4, 1))
    strategy5 = ((2, 1), (1, 4))

    _build_and_compile(
        "semi_auto_parallel",
        (strategy1, strategy2, strategy3, strategy4, strategy5))


def test_self_attention_dp():
    """Compiles with pure data-parallel (batch dim split 8 ways) strategies."""
    dp_strategy = ((8, 1), (1, 1))

    _build_and_compile("semi_auto_parallel", (dp_strategy,) * 5)


def test_self_attention_auto():
    """Compiles with strategies left to the auto-parallel searcher."""
    _build_and_compile("auto_parallel")