• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.common.dtype as mstype
19import mindspore.nn as nn
20from mindspore import Tensor
21from mindspore import context
22from mindspore.common.api import _cell_graph_executor
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25
26
# Gradient operator that returns the gradients with respect to all inputs of
# the wrapped network (get_all=True).
grad_all = C.GradOperation(get_all=True)
# Variant that additionally takes the sensitivity (gradient of the network
# output) as an extra trailing argument; GradWrap and GradWrap2 pass it last.
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
29
30
class GradWrap(nn.Cell):
    """Cell that differentiates ``network`` with a caller-supplied sensitivity.

    ``construct`` takes the three network inputs ``x``, ``y``, ``b`` plus
    ``sens`` and returns whatever ``grad_all_with_sens`` produces for them.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, b, sens):
        # Build the gradient function for the wrapped network, then call it
        # with the inputs followed by the sensitivity.
        grad_fn = grad_all_with_sens(self.network)
        return grad_fn(x, y, b, sens)
38
39
class GradWrap2(nn.Cell):
    """Cell that differentiates ``network`` with an all-ones sensitivity.

    The sensitivity is created inside ``construct``: a float32 tensor filled
    with 1.0 whose shape is taken from the forward output of the network.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, b):
        # Run the forward pass only to learn the shape of the output.
        loss = self.network(x, y, b)
        loss_shape = P.Shape()(loss)
        sens = P.Fill()(mstype.float32, loss_shape, 1.0)
        grad_fn = grad_all_with_sens(self.network)
        return grad_fn(x, y, b, sens)
49
50
class GradWrap3(nn.Cell):
    """Cell that applies ``grad_all`` to a three-input ``network``.

    No sensitivity argument is passed; ``construct`` forwards ``x``, ``y`` and
    ``bias`` unchanged to the gradient function.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, bias):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y, bias)
58
class GradWrap4(nn.Cell):
    """Cell that applies ``grad_all`` to a two-input ``network``.

    No sensitivity argument is passed; ``construct`` forwards ``x`` and ``y``
    unchanged to the gradient function.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)
66
def compile_net(net, x, y, b):
    """Compile ``net`` for the three inputs ``x``, ``y`` and ``b``.

    Calls ``set_auto_parallel()`` and ``set_train()`` on the cell before
    handing it to the graph executor. Nothing is returned.
    """
    net.set_auto_parallel()
    net.set_train()
    inputs = (x, y, b)
    _cell_graph_executor.compile(net, *inputs)
71
def compile_net_no_bias(net, x, y):
    """Compile ``net`` for the two inputs ``x`` and ``y``.

    Calls ``set_auto_parallel()`` and ``set_train()`` on the cell before
    handing it to the graph executor. Nothing is returned.
    """
    net.set_auto_parallel()
    net.set_train()
    inputs = (x, y)
    _cell_graph_executor.compile(net, *inputs)
76
def test_no_grad():
    """
    Feature: semi auto parallel compilation without a gradient wrapper.
    Description: two chained, sharded MatMul ops compiled with device_num=8.
    Expectation: compilation completes without raising.
    """
    class Net(nn.Cell):
        def __init__(self, first_strategy, second_strategy):
            super().__init__()
            self.matmul1 = P.MatMul().shard(first_strategy)
            self.matmul2 = P.MatMul().shard(second_strategy)

        def construct(self, x, y, b):
            return self.matmul2(self.matmul1(x, y), b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    # The network is built before the parallel mode is switched, in the same
    # order as the context calls were originally issued.
    net = Net(((4, 2), (2, 1)), ((2, 4), (4, 1)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    input_shapes = ([128, 32], [32, 64], [64, 64])
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes)
    compile_net(net, x, y, b)
100
101
def test_grad_sens_parameter_type():
    """
    Feature: semi auto parallel gradient with the sensitivity passed as an input.
    Description: GradWrap around two sharded MatMul ops, compiled with
        device_num=64 in the 'train' phase.
    Expectation: ``net.parameter_layout_dict`` equals the expected layouts for
        'x', 'y', 'b' and 'sens'.
    """
    class Net(nn.Cell):
        def __init__(self, first_strategy, second_strategy):
            super().__init__()
            self.matmul1 = P.MatMul().shard(first_strategy)
            self.matmul2 = P.MatMul().shard(second_strategy)

        def construct(self, x, y, b):
            return self.matmul2(self.matmul1(x, y), b)

    def expected_layout(tensor_map, slice_shape):
        # Every expected entry shares the leading [8, 8] and the trailing
        # (0, True, ''); only the two middle lists differ.
        # NOTE(review): the argument names follow MindSpore's layout tuple
        # convention (device matrix, tensor map, slice shape, ...) - confirm.
        return ([8, 8], tensor_map, slice_shape, 0, True, '')

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=0)
    net = GradWrap(Net(((8, 1), (1, 8)), ((8, 8), (8, 1))))

    input_shapes = ([128, 32], [32, 64], [64, 64])
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes)

    sens = Tensor(np.ones([128, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b, sens, phase='train', auto_parallel_mode=True)

    expect_dict = {
        'x': expected_layout([1, -1], [16, 32]),
        'y': expected_layout([-1, 0], [32, 8]),
        'b': expected_layout([0, -1], [8, 64]),
        'sens': expected_layout([1, -1], [16, 64]),
    }
    assert net.parameter_layout_dict == expect_dict
133
134
def test_grad_sens_tensor_type():
    """
    Feature: semi auto parallel gradient with the sensitivity built in-graph.
    Description: GradWrap2 (all-ones sensitivity created from the output shape)
        around two sharded MatMul ops, compiled with device_num=8.
    Expectation: compilation completes without raising.
    """
    class Net(nn.Cell):
        def __init__(self, first_strategy, second_strategy):
            super().__init__()
            self.matmul1 = P.MatMul().shard(first_strategy)
            self.matmul2 = P.MatMul().shard(second_strategy)

        def construct(self, x, y, b):
            return self.matmul2(self.matmul1(x, y), b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    # The network is built before the parallel mode is switched, in the same
    # order as the context calls were originally issued.
    net = GradWrap2(Net(((4, 2), (2, 1)), ((2, 4), (4, 1))))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    input_shapes = ([128, 32], [32, 64], [64, 64])
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes)
    compile_net(net, x, y, b)
158
159
def test_grad_sens_scalar_broadcast():
    """
    Feature: semi auto parallel gradient without an explicit sensitivity.
    Description: GradWrap4 around a sharded MatMul (transpose_b=True) followed
        by a sharded ReduceSum over axes (0, 1), compiled with device_num=16.
    Expectation: compilation completes without raising.
    """
    class Net(nn.Cell):
        def __init__(self, matmul_strategy, reduce_strategy):
            super().__init__()
            self.fc_nobias = P.MatMul(transpose_b=True).shard(matmul_strategy)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(reduce_strategy)

        def construct(self, x, y):
            return self.reduce_sum(self.fc_nobias(x, y), (0, 1))

    context.set_auto_parallel_context(device_num=16, global_rank=0)
    # The network is built before the parallel mode is switched, in the same
    # order as the context calls were originally issued.
    net = GradWrap4(Net(((4, 1), (4, 1)), ((4, 1),)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    # Both inputs are separate all-ones tensors of the same shape.
    x, y = (Tensor(np.ones([64, 32]), dtype=ms.float32) for _ in range(2))
    compile_net_no_bias(net, x, y)
181