# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compile-time tests for gradient graphs under semi-auto-parallel MatMul sharding."""

import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P


# Meta-ops that build a gradient graph over a wrapped network:
# grad_all differentiates w.r.t. every input; the *_with_sens variant
# additionally takes an explicit sensitivity (initial output-gradient) tensor.
grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)


class GradWrap(nn.Cell):
    """Gradients of a 3-input network, with the sens tensor passed in by the caller."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, b, sens):
        return grad_all_with_sens(self.network)(x, y, b, sens)


class GradWrap2(nn.Cell):
    """Gradients of a 3-input network; the sens tensor is built internally.

    A ones tensor shaped like the forward output is synthesized with Fill,
    so callers do not need to supply the sensitivity themselves.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, b):
        loss = self.network(x, y, b)
        sens = P.Fill()(mstype.float32, P.Shape()(loss), 1.0)
        return grad_all_with_sens(self.network)(x, y, b, sens)


class GradWrap3(nn.Cell):
    """Gradients of a 3-input network without any sensitivity parameter."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, bias):
        return grad_all(self.network)(x, y, bias)


class GradWrap4(nn.Cell):
    """Gradients of a 2-input network without any sensitivity parameter."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


def compile_net(net, x, y, b):
    """Compile *net* in training mode with three tensor inputs."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def compile_net_no_bias(net, x, y):
    """Compile *net* in training mode with two tensor inputs (no bias)."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def test_no_grad():
    """A plain forward network (no grad wrapper) compiles under sharding."""

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    first_strategy = ((4, 2), (2, 1))
    second_strategy = ((2, 4), (4, 1))
    net = Net(first_strategy, second_strategy)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_grad_sens_parameter_type():
    """A caller-supplied sens tensor is sliced like the forward output.

    After compilation each graph input, including ``sens``, must carry the
    expected (device matrix, tensor map, slice shape, ...) layout.
    """

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=0)
    first_strategy = ((8, 1), (1, 8))
    second_strategy = ((8, 8), (8, 1))
    net = GradWrap(Net(first_strategy, second_strategy))

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    sens = Tensor(np.ones([128, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)
    # Layout tuple: (device matrix, tensor map, slice shape, field size, uniform split, opt shard group).
    x_layout = ([8, 8], [1, -1], [16, 32], 0, True, '')
    y_layout = ([8, 8], [-1, 0], [32, 8], 0, True, '')
    b_layout = ([8, 8], [0, -1], [8, 64], 0, True, '')
    sens_layout = ([8, 8], [1, -1], [16, 64], 0, True, '')
    expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout}
    assert net.parameter_layout_dict == expect_dict


def test_grad_sens_tensor_type():
    """A sens tensor synthesized inside the wrapper (GradWrap2) compiles."""

    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    first_strategy = ((4, 2), (2, 1))
    second_strategy = ((2, 4), (4, 1))
    net = GradWrap2(Net(first_strategy, second_strategy))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_grad_sens_scalar_broadcast():
    """Gradients of a scalar loss (ReduceSum over all axes) compile; the
    implicit scalar sens must broadcast back through the sharded ops."""

    class Net(nn.Cell):
        def __init__(self, strategy0, strategy1):
            super().__init__()
            self.fc_nobias = P.MatMul(transpose_b=True).shard(strategy0)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy1)

        def construct(self, x, y):
            out = self.fc_nobias(x, y)
            out = self.reduce_sum(out, (0, 1))
            return out

    context.set_auto_parallel_context(device_num=16, global_rank=0)
    fc_strategy = ((4, 1), (4, 1))
    reduce_strategy = ((4, 1),)
    net = GradWrap4(Net(fc_strategy, reduce_strategy))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 32]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)