• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from tests.ut.python.ops.test_math_ops import VirtualLoss
27
# All tests in this file compile networks under graph mode.
context.set_context(mode=context.GRAPH_MODE)


# Gradient operation that returns gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)
32
33
class NetWithLoss(nn.Cell):
    """Attach a VirtualLoss head to a three-input network.

    Wrapping the network this way gives the compiled graph a scalar
    output, which the parallel strategy search requires.
    """

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x, y, b):
        out = self.network(x, y, b)
        return self.loss(out)
43
44
class GradWrap(nn.Cell):
    """Cell that yields gradients of ``network`` w.r.t. all three inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y, b)
52
53
def compile_net(net, x, y, b, phase):
    """Compile ``net`` with inputs ``x``, ``y``, ``b`` for the given phase.

    Enables auto-parallel and training mode on the cell before handing it
    to the graph executor; the compile step triggers the parallel
    strategy search that the tests below inspect.
    """
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b, phase=phase)
58
59
def test_auto_parallel_arithmetic():
    """Strategy search for MatMul followed by an element-wise FloorDiv
    with fully matching operand shapes (no broadcasting)."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            return self.floordiv(self.matmul(x, y), b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    b = Tensor(np.ones([64, 128]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')

    strategies = _cell_graph_executor._get_shard_strategy(net)
    for name, strategy in strategies.items():
        # Op instance names carry an "-op<id>" suffix; match on the prefix.
        if 'FloorDiv-op' in name:
            assert strategy == [[2, 4], [2, 4]]
        elif 'MatMul-op' in name:
            assert strategy == [[2, 1], [1, 4]]
87
88
def test_auto_parallel_arithmetic_broadcast_both():
    """Strategy search when FloorDiv broadcasts on both operands:
    matmul output is [64, 1] and ``b`` is [1, 64]."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            prod = self.matmul(x, y)
            return self.floordiv(prod, b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')

    # Both ops are expected to shard only the non-broadcast row dimension.
    expected = {
        'FloorDiv-op': [[8, 1], [1, 1]],
        'MatMul-op': [[8, 1], [1, 1]],
    }
    strategies = _cell_graph_executor._get_shard_strategy(net)
    for name, strategy in strategies.items():
        for tag, want in expected.items():
            if tag in name:
                assert strategy == want
116
117
def test_auto_parallel_arithmetic_broadcast_right():
    """Strategy search when FloorDiv's right operand is a rank-1 tensor
    ([32]) broadcast against the [64, 32] matmul output."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    inputs = [
        Tensor(np.ones([64, 32]), dtype=ms.float32),
        Tensor(np.ones([32, 32]), dtype=ms.float32),
        Tensor(np.ones([32]), dtype=ms.float32),  # rank-1, broadcast operand
    ]
    compile_net(net, inputs[0], inputs[1], inputs[2], phase='train')

    strategies = _cell_graph_executor._get_shard_strategy(net)
    for name, strategy in strategies.items():
        if re.search('FloorDiv-op', name):
            # The rank-1 operand gets a rank-1 strategy.
            assert strategy == [[4, 2], [2]]
        elif re.search('MatMul-op', name):
            assert strategy == [[4, 1], [1, 2]]
145
146
def test_auto_parallel_arithmetic_broadcast_left():
    """Strategy search when FloorDiv's left operand is broadcast: the
    [64, 32] matmul output is divided by a rank-3 [128, 64, 32] tensor."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            product = self.matmul(x, y)
            return self.floordiv(product, b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 32]), dtype=ms.float32)
    # Rank-3 divisor: the matmul output broadcasts along its leading axis.
    b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
    compile_net(net, x, y, b, phase="train")

    strategies = _cell_graph_executor._get_shard_strategy(net)
    for name, strategy in strategies.items():
        if name.find('FloorDiv-op') != -1:
            # Broadcast axis of the rank-3 operand stays unsharded (1).
            assert strategy == [[4, 2], [1, 4, 2]]
        elif name.find('MatMul-op') != -1:
            assert strategy == [[4, 1], [1, 2]]
174