• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common.api import _cell_graph_executor
22from mindspore.ops import composite as C
23from mindspore.ops import operations as P
24from tests.ut.python.ops.test_math_ops import VirtualLoss
25
26
# Shared gradient operator used by GradWrap below. get_all=True asks
# GradOperation for the gradients of all network inputs (per the MindSpore
# GradOperation documentation).
grad_all = C.GradOperation(get_all=True)
28
29
class NetWithLoss(nn.Cell):
    """Feed the output of ``network`` into a ``VirtualLoss``.

    Args:
        network: cell called as ``network(x, y)``; its result is passed to
            ``VirtualLoss``.
    """

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        # Keep the original attribute registration order (loss, then network).
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        return self.loss(self.network(x, y))
39
40
class GradWrap(nn.Cell):
    """Wrap ``network`` so that calling this cell returns ``grad_all(network)(x, y)``.

    Args:
        network: cell taking ``(x, y)`` whose gradients are requested.
    """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)
48
49
class Net(nn.Cell):
    """SparseGatherV2 on ``x`` with a constant index, followed by ``Mul`` with ``y``.

    Args:
        axis: axis argument forwarded to SparseGatherV2.
        strategy1: value passed to ``SparseGatherV2().shard``.
        strategy2: value passed to ``Mul().shard``.
        shape: shape of the constant int32 index tensor (all ones);
            ``[64, 64]`` when None.
        target: stored as the gather primitive's ``"primitive_target"`` attribute.
    """

    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
        super().__init__()
        index_shape = [64, 64] if shape is None else shape
        gather = P.SparseGatherV2().shard(strategy1)
        self.gatherv2 = gather.add_prim_attr("primitive_target", target)
        self.mul = P.Mul().shard(strategy2)
        # Every index value is 1.
        self.index = Tensor(np.ones(index_shape), dtype=ms.int32)
        self.axis = axis

    def construct(self, x, y):
        gathered = self.gatherv2(x, self.index, self.axis)
        return self.mul(gathered, y)
64
65
def test_gatherv2_semi_auto0():
    """semi_auto_parallel, 8 devices: gather on axis 0, strategy1 ((1, 8), (1, 1)).

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((1, 8), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
77
78
def test_gatherv2_semi_auto1():
    """semi_auto_parallel, 8 devices: gather on axis 0, strategy1 ((8, 1), (1, 1)).

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((8, 1), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
90
91
def test_gatherv2_semi_auto2():
    """semi_auto_parallel, 8 devices: gather on axis 0, strategy1 ((2, 4), (1, 1)).

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((2, 4), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
103
104
def test_gatherv2_semi_auto3():
    """semi_auto_parallel, 8 devices: gather on axis 1, strategy1 ((1, 8), (1, 1)).

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((1, 8), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
116
117
def test_gatherv2_semi_auto4():
    """semi_auto_parallel, 8 devices: gather on axis 1, strategy1 ((8, 1), (1, 1)), x is 64x32.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((8, 1), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
129
130
def test_gatherv2_semi_auto5():
    """semi_auto_parallel, 8 devices: gather on axis 1, strategy1 ((2, 4), (1, 1)), x is 64x32.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((2, 4), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(1, strategy1, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
142
143
def test_gatherv2_semi_auto6():
    """semi_auto_parallel, 8 devices: gather on axis 0 with no strategy1 (None), x is 64x32.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(0, None, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
154
155
def test_gatherv2_semi_auto7():
    """semi_auto_parallel, 8 devices: gather on axis 1 with no strategy1 (None), x is 64x32.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = GradWrap(NetWithLoss(Net(1, None, strategy2)))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
166
167
def test_gatherv2_auto0():
    """auto_parallel, 8 devices: gather on axis 0; Net gets no shard strategies.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
        net = GradWrap(NetWithLoss(Net(0)))
        net.set_auto_parallel()
        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
176
177
def test_gatherv2_auto1():
    """auto_parallel, 8 devices: gather on axis 1; Net gets no shard strategies.

    Compiles the gradient graph. The global auto-parallel context is reset
    afterwards (even on failure) so the settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
        net = GradWrap(NetWithLoss(Net(1)))
        net.set_auto_parallel()
        x = Tensor(np.ones([64, 32]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
186
187
def test_gatherv2_cpu0():
    """semi_auto_parallel, 8 devices: gather marked primitive_target="CPU", strategy1 ((8, 1), (1, 1)).

    Compiles NetWithLoss directly (no GradWrap). The global auto-parallel
    context is reset afterwards (even on failure) so the settings do not leak
    into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((8, 1), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
199
200
def test_gatherv2_cpu1():
    """semi_auto_parallel, 16 devices: gather marked primitive_target="CPU", strategy1 ((16, 1), (1, 1)).

    Compiles NetWithLoss directly (no GradWrap). The global auto-parallel
    context is reset afterwards (even on failure) so device_num=16 and the
    other settings do not leak into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((16, 1), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
212
213
def test_gatherv2_cpu2():
    """semi_auto_parallel, 8 devices: gather marked primitive_target="CPU", strategy1 ((1, 8), (1, 1)).

    Compiles NetWithLoss directly (no GradWrap). The global auto-parallel
    context is reset afterwards (even on failure) so the settings do not leak
    into other tests.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
        strategy1 = ((1, 8), (1, 1))
        strategy2 = ((4, 2, 1), (4, 2, 1))
        net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
        net.set_auto_parallel()

        x = Tensor(np.ones([64, 64]), dtype=ms.float32)
        y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
        net.set_train()
        _cell_graph_executor.compile(net, x, y)
    finally:
        context.reset_auto_parallel_context()
225