# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_pynative_mixed_precision_cells """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.nn import Cell
from mindspore.nn import ReLU
from mindspore.common.tensor import Tensor

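# Note on the technique under test: Cell.to_float(dst_type) marks a cell so
# that its float tensor inputs are cast to dst_type at execution time. The
# cases below mix a float16 network with float32 child cells and check that
# graph mode and PyNative mode produce matching results.
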
class MetaFactory:
    """Records device metadata from the current MindSpore context."""
    def __init__(self):
        self.device_target = context.get_context('device_target')
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None

class ReluTanhSoftmax(Cell, MetaFactory):
    """Returns relu(x), tanh(relu(x)) and softmax(relu(x))."""
    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

    def construct(self, x):
        x = self.relu(x)
        y = self.tanh(x)
        z = self.softmax(x)
        return x, y, z

class Add(Cell, MetaFactory):
    """Element-wise addition cell wrapping P.Add()."""
    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.add = P.Add()

    def construct(self, x, y):
        return self.add(x, y)

class ReluTanhAdd(Cell, MetaFactory):
    """Returns relu(x) + tanh(x), with the addition done by a child Add cell."""
    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.add = Add()

    def construct(self, x):
        x_1 = self.relu(x)
        y = self.tanh(x)
        x = self.add(x_1, y)
        return x

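# Quick shape sanity check for the cells above (illustrative only, not part
# of the original test flow); ReLU, Tanh and Softmax all preserve shape:
#   net = ReluTanhSoftmax()
#   r, t, s = net(Tensor(np.ones((1, 3, 28, 28), np.float32)))
#   # r, t and s each keep the input shape (1, 3, 28, 28)
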
def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements stays below rtol."""
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])

def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Compare two arrays with np.allclose; on mismatch, report the offending elements."""
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)

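# Hypothetical usage of the helper above (values for illustration only):
# elements within atol + |data_me| * rtol pass outright, and a comparison
# still succeeds as long as fewer than a fraction rtol of all elements fall
# outside that band.
#   allclose_nparray(np.ones(4), np.ones(4) + 1e-5, rtol=0.001, atol=0.001)
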
def mixed_precision_multiple_cells_temp_01():
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhSoftmax()
    # Cast the whole network to float16 first, then override per-child dtypes.
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    out_me_relu_01, out_me_tanh_01, out_me_softmax_01 = net(Tensor(x))
    return out_me_relu_01, out_me_tanh_01, out_me_softmax_01

def mixed_precision_multiple_cells_temp_02():
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhSoftmax()
    # Same as temp_01, but the per-child casts come before the network cast.
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    net.to_float(ms.float16)
    out_me_relu_02, out_me_tanh_02, out_me_softmax_02 = net(Tensor(x))
    return out_me_relu_02, out_me_tanh_02, out_me_softmax_02

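# temp_01 and temp_02 differ only in the order of the to_float calls
# (network-level cast before vs. after the per-child overrides); both
# orderings are checked for graph/PyNative consistency by the drivers below.
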
def mixed_precision_multiple_cells_temp_03():
    np.random.seed(1)
    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    net = ReluTanhAdd()
    # float16 network with the relu and add children kept in float32.
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.add.to_float(ms.float32)
    out_me = net(Tensor(x))
    return out_me

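# Each driver below runs one scenario twice on the current device target,
# once in GRAPH_MODE and once in PYNATIVE_MODE, and asserts the outputs
# agree within rtol = atol = 0.001.
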
def mixed_precision_multiples_cell_01():
    context.set_context(mode=context.GRAPH_MODE, device_target=context.get_context('device_target'))
    graph_relu_01, graph_tanh_01, graph_softmax_01 = mixed_precision_multiple_cells_temp_01()

    context.set_context(mode=context.PYNATIVE_MODE, device_target=context.get_context('device_target'))
    pynative_relu_01, pynative_tanh_01, pynative_softmax_01 = mixed_precision_multiple_cells_temp_01()

    allclose_nparray(graph_relu_01.asnumpy(), pynative_relu_01.asnumpy(), 0.001, 0.001)
    allclose_nparray(graph_tanh_01.asnumpy(), pynative_tanh_01.asnumpy(), 0.001, 0.001)
    allclose_nparray(graph_softmax_01.asnumpy(), pynative_softmax_01.asnumpy(), 0.001, 0.001)

def mixed_precision_multiples_cell_02():
    context.set_context(mode=context.GRAPH_MODE, device_target=context.get_context('device_target'))
    graph_relu_02, graph_tanh_02, graph_softmax_02 = mixed_precision_multiple_cells_temp_02()

    context.set_context(mode=context.PYNATIVE_MODE, device_target=context.get_context('device_target'))
    pynative_relu_02, pynative_tanh_02, pynative_softmax_02 = mixed_precision_multiple_cells_temp_02()

    allclose_nparray(graph_relu_02.asnumpy(), pynative_relu_02.asnumpy(), 0.001, 0.001)
    allclose_nparray(graph_tanh_02.asnumpy(), pynative_tanh_02.asnumpy(), 0.001, 0.001)
    allclose_nparray(graph_softmax_02.asnumpy(), pynative_softmax_02.asnumpy(), 0.001, 0.001)

def mixed_precision_multiples_cell_03():
    context.set_context(mode=context.GRAPH_MODE, device_target=context.get_context('device_target'))
    graph_output_03 = mixed_precision_multiple_cells_temp_03()

    context.set_context(mode=context.PYNATIVE_MODE, device_target=context.get_context('device_target'))
    pynative_output_03 = mixed_precision_multiple_cells_temp_03()

    allclose_nparray(graph_output_03.asnumpy(), pynative_output_03.asnumpy(), 0.001, 0.001)

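# The marks below follow MindSpore's CI conventions (an assumption about this
# file's context): level0/level1 select the test gate, the platform_* marks
# pick the backends a case may run on, and env_onecard requests a single device.
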
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_01():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_01()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_01():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_01()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_02():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_02()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_02():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_02()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_03():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_03()

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_03():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_03()