# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""test_mix_precision"""
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import ParameterTuple
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.context import ParallelMode
from tests.ops_common import convert
from ....train_step_wrap import train_step_with_loss_warp


class LeNet5(nn.Cell):
    """LeNet5 network used as the fixture for the mixed-precision tests."""

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class NetForConcat(nn.Cell):
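    """Concat the input with a constant Tensor attribute and a Parameter."""
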
    def __init__(self):
        super(NetForConcat, self).__init__()
        self.concat = P.Concat()
        self.x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
        self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')

    def construct(self, x0):
        return self.concat((x0, self.x1, self.x2))


def test_add_cast_flag():
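    """Cast the net to fp16 while keeping the last Dense layer (fc3) in fp32."""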
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net.to_float(mstype.float16)
    net.fc3.to_float(mstype.float32)
    net = train_step_with_loss_warp(net)
    net.set_train()
    net(predict, label)


def test_add_cast_flag_tensor():
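    """Recursively set the fp16 flag on a net holding both a Tensor and a Parameter."""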
    x1 = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = NetForConcat()
    net.add_flags_recursive(fp16=True)
    net.set_train()
    net(x1)


def test_on_momentum():
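    """Run a wrapped train step with the whole net, loss included, cast to fp16."""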
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net = train_step_with_loss_warp(net).to_float(mstype.float16)
    net.set_train()
    net(predict, label)


def test_data_parallel_with_cast():
    """Compile a mixed-precision LeNet5 train graph under 8-device data parallelism."""
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8)
    predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = LeNet5()
    net.to_float(mstype.float16)
    net.fc3.to_float(mstype.float32)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                         learning_rate=0.1,
                         momentum=0.9)
    net = WithLossCell(net, loss_fn)
    net = TrainOneStepCell(net, optimizer)

    _cell_graph_executor.compile(net, predict, label)
    context.reset_auto_parallel_context()


class NetForPReLU(nn.Cell):
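    """Wrap nn.PReLU in a cell so precision flags can be applied to the activation."""
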
    def __init__(self):
        super(NetForPReLU, self).__init__()
        self.prelu = nn.PReLU()

    def construct(self, x):
        return self.prelu(x)


def test_nn_prelu():
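    """Compile a PReLU net with the fp16 flag set recursively."""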
    x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
    net = NetForPReLU().set_train()
    net.add_flags_recursive(fp16=True)
    _cell_graph_executor.compile(net, x)


class NetForCast(nn.Cell):
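    """Multiply the input by a scalar Tensor constant and a Parameter."""
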
    def __init__(self):
        super(NetForCast, self).__init__()
        self.x1 = Tensor(1.0, mstype.float32)
        self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2')

    def construct(self, x0):
        x = self.x1 * x0 * self.x2
        return x


def test_cast():
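    """Run a net mixing a constant Tensor and a Parameter under the fp16 flag."""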
    x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01)
    net = NetForCast()
    net.add_flags_recursive(fp16=True)
    net(x)


class IRBlockZ(nn.Cell):
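    """A 3x3 convolution followed by a channel-wise PReLU activation."""
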
    def __init__(self, inplanes, planes):
        super(IRBlockZ, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, pad_mode="same", group=1, has_bias=False,
                               dilation=1)
        self.act_layer = nn.PReLU(planes)

    def construct(self, x):
        out = self.conv1(x)
        return self.act_layer(out)


class GetParamGrad(nn.Cell):
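    """Return gradients of the network's trainable parameters for a given sens tensor."""
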
    def __init__(self, network):
        super(GetParamGrad, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)

    def construct(self, data, sens):
        weights = self.weights
        return self.grad(self.network, weights)(data, sens)


def test_grad_conv_prelu():
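    """Backpropagate through an fp16 conv + PReLU block with an fp16 sens input."""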
    shapes = [[64, 64, 112, 112]]
    outshape = [[64, 64, 112, 112]]
    net = IRBlockZ(inplanes=64, planes=64).add_flags_recursive(fp16=True)
    inputs = [convert(shp, dtype=np.float16) for shp in shapes]
    sens_shape = outshape[0]
    sens = convert(sens_shape, dtype=np.float16)
    all_inputs = inputs + [sens]
    net = GetParamGrad(net)
    net.set_train()
    net(*all_inputs)


def test_dict_cast():
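    """mixed_precision_cast should also cast Tensors held in dicts and **kwargs."""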
    class FirstNet(nn.Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()
            self.sub = P.Sub()

        def construct(self, tensor_a, tensor_b):
            a = F.mixed_precision_cast(mstype.float16, tensor_a)
            b = F.mixed_precision_cast(mstype.float16, tensor_b)
            c = self.sub(a, b)
            dictionary = {"key": a}
            result = self.net(c, key1=a, key2=dictionary)
            return result

    class SecondNet(nn.Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.add = P.Add()

        def construct(self, tensor_c, **kwargs):
            d = F.mixed_precision_cast(mstype.float16, tensor_c)
            dict_cast = F.mixed_precision_cast(mstype.float16, kwargs)
            e = self.add(d, dict_cast["key1"])
            f = self.add(e, dict_cast["key2"]["key"])
            return f

    x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
    y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
    net = FirstNet()
    net(x, y)


def test_kwarg_cast():
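    """Keyword and dict-valued Tensor arguments should be cast by an fp16 callee."""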
    class FirstNet(nn.Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet().add_flags_recursive(fp16=True)
            self.add = P.Add()

        def construct(self, tensor_a, tensor_b):
            tensor_c = self.add(tensor_a, tensor_b)
            dictionary = {"key": tensor_a}
            result = self.net(key1=tensor_c, key2=dictionary)
            return result

    class SecondNet(nn.Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.add = P.Add()

        def construct(self, key1=1, key2=2):
            tensor_d = self.add(key1, key2["key"])
            return tensor_d

    x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32)
    y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32)
    net = FirstNet()
    net(x, y)