# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell

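# Strategy compilation for (semi-)auto parallel runs in graph mode.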
context.set_context(mode=context.GRAPH_MODE)


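# Gradient operation that returns gradients with respect to all network inputs.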
grad_all = C.GradOperation(get_all=True)


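# Wraps a backbone network with OneHot label encoding and SoftmaxCrossEntropyWithLogits;
# strategy3 shards the OneHot and strategy4 shards the loss.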
class NetWithLoss(nn.Cell):
    def __init__(self, network, strategy3, strategy4, axis):
        super(NetWithLoss, self).__init__()
        self.one_hot = P.OneHot(axis=axis).shard(strategy3)
        self.on_value = Tensor(2.0, ms.float32)
        self.off_value = Tensor(1.0, ms.float32)
        self.loss = P.SoftmaxCrossEntropyWithLogits().shard(strategy4)
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y)
        label = self.one_hot(b, 64, self.on_value, self.off_value)
        return self.loss(predict, label)[0]


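# Computes gradients of the wrapped network with respect to all of its inputs.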
class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


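# Backbone network: MatMul followed by GeLU, sharded by strategy1 and strategy2.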
class Net(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul = P.MatMul().shard(strategy1)
        self.gelu = P.GeLU().shard(strategy2)

    def construct(self, x, y):
        out = self.matmul(x, y)
        out = self.gelu(out)
        return out


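# Builds GradWrap(_VirtualDatasetCell(NetWithLoss(Net(...)))) and compiles it under
# semi_auto_parallel (or auto_parallel when auto=True) with inputs x: [64, 32],
# y: [32, 64] and label indices b: [64].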
def compile_graph(strategy1, strategy2, strategy3, strategy4, auto=False, onehot_axis=-1):
    net = GradWrap(_VirtualDatasetCell(NetWithLoss(Net(strategy1, strategy2), strategy3, strategy4, axis=onehot_axis)))
    net.set_auto_parallel()
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.int32)
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


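# Model parallel: strategy3 = ((1, 16), (), ()) keeps the batch dimension whole and
# splits the OneHot output along the class (depth) dimension across all 16 devices.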
def test_onehot_model_parallel():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((1, 16), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)


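# Batch parallel: strategy3 = ((16, 1), (), ()) splits the OneHot along the batch
# dimension instead of the class dimension.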
def test_onehot_batch_parallel():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)


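# strategy3 = ((16,), (), ()) does not match the rank of the 2-D OneHot output,
# so strategy checking is expected to reject it during compilation.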
def test_onehot_batch_parallel_invalid_strategy():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16,), (), ())
    strategy4 = ((16, 1), (16, 1))
    with pytest.raises((ValueError, TypeError, RuntimeError)):
        compile_graph(strategy1, strategy2, strategy3, strategy4)


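# strategy3 = ((4, 1), (), ()) uses only 4 of the 16 devices for the OneHot, so the
# same OneHot computation is repeated across device groups.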
def test_onehot_repeated_calculation():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((4, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4)


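# With all strategies set to None, auto_parallel searches sharding strategies
# instead of using user-specified ones.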
def test_onehot_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    strategy4 = None
    compile_graph(strategy1, strategy2, strategy3, strategy4, auto=True)


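# The tests below repeat the cases above with OneHot axis=0 instead of the default axis=-1.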
def test_onehot_batch_parallel_axis0():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((16, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)


# Auto parallel strategy generation for OneHot with axis=0 is not supported yet.
def test_onehot_batch_parallel_invalid_strategy_axis0():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = None
    strategy4 = ((16, 1), (16, 1))
    with pytest.raises((ValueError, TypeError, RuntimeError)):
        compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)


def test_onehot_repeated_calculation_axis0():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    strategy1 = ((2, 4), (4, 2))
    strategy2 = ((2, 8),)
    strategy3 = ((4, 1), (), ())
    strategy4 = ((16, 1), (16, 1))
    compile_graph(strategy1, strategy2, strategy3, strategy4, onehot_axis=0)


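# Same as test_onehot_auto, but with axis=0 and a non-zero global rank.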
def test_onehot_auto_axis0():
    context.set_auto_parallel_context(device_num=16, global_rank=14)
    strategy1 = None
    strategy2 = None
    strategy3 = None
    strategy4 = None
    compile_graph(strategy1, strategy2, strategy3, strategy4, auto=True, onehot_axis=0)