# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_mixed_precision_cells

Runs small networks whose sub-cells have different ``to_float`` settings in
both GRAPH_MODE and PYNATIVE_MODE and checks that the two modes agree.
"""
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context
from mindspore.nn import Cell
from mindspore.nn import ReLU
from mindspore.common.tensor import Tensor


class MetaFactory:
    """Mixin that records the current device target on the instance."""

    def __init__(self):
        self.device_target = context.get_context('device_target')
        # Not read anywhere in this file; presumably placeholders for
        # distributed variants of these cells -- verify before relying on them.
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None


class ReluTanhSoftmax(Cell, MetaFactory):
    """Apply ReLU, then feed the ReLU output to both Tanh and Softmax.

    ``construct`` returns the tuple ``(relu(x), tanh(relu(x)), softmax(relu(x)))``.
    """

    def __init__(self):
        super().__init__()
        # Cell's __init__ is reached through super(); the mixin is initialised
        # explicitly.
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

    def construct(self, x):
        x = self.relu(x)
        y = self.tanh(x)
        z = self.softmax(x)
        return x, y, z


class Add(Cell, MetaFactory):
    """Wrap ``P.Add`` in a Cell so it can be given its own ``to_float`` setting."""

    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.add = P.Add()

    def construct(self, x, y):
        return self.add(x, y)


class ReluTanhAdd(Cell, MetaFactory):
    """Return ``relu(x) + tanh(x)``, computing the sum with the ``Add`` sub-cell.

    Unlike ``ReluTanhSoftmax``, tanh is applied to the raw input ``x`` rather
    than to the ReLU output.
    """

    def __init__(self):
        super().__init__()
        MetaFactory.__init__(self)
        self.relu = ReLU()
        self.tanh = nn.Tanh()
        self.add = Add()

    def construct(self, x):
        x_1 = self.relu(x)
        # tanh intentionally(?) reads the original x, not x_1.
        y = self.tanh(x)
        x = self.add(x_1, y)
        return x


def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements is below ``rtol``.

    An element is out of tolerance when
    ``|data_expected - data_me| > atol + |data_me| * rtol``. ``rtol`` is
    reused as the maximum allowed ratio of such elements. On failure, the
    message lists the offending expected values, actual values and absolute
    errors.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test, same shape as the reference.
        rtol (float): Relative tolerance, also the allowed mismatch ratio.
        atol (float): Absolute tolerance.
    """
    assert data_expected.shape == data_me.shape
    # ``size`` counts elements without the copy ``flatten()`` would allocate.
    total_count = data_expected.size
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert loss_count / total_count < rtol, (
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater]))


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Assert that two arrays are close, tolerating a small ratio of outliers.

    If ``data_expected`` contains any NaN, a strict ``np.allclose`` check is
    required. Otherwise, arrays that fail ``np.allclose`` still pass when
    ``_count_unequal_element`` finds the mismatch ratio below ``rtol``.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test.
        rtol (float): Relative tolerance (also the allowed mismatch ratio).
        atol (float): Absolute tolerance.
        equal_nan (bool): Forwarded to ``np.allclose``.
    """
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


def _seeded_input():
    """Reseed NumPy's global RNG with 1 and return a (1, 3, 28, 28) float32 array."""
    np.random.seed(1)
    return np.random.randn(1, 3, 28, 28).astype(np.float32)


def mixed_precision_multiple_cells_temp_01():
    """Run ReluTanhSoftmax after a net-wide float16 ``to_float`` call and
    then per-child overrides (``relu`` -> float32, ``softmax`` -> float16).

    Returns:
        tuple: ``(relu_out, tanh_out, softmax_out)`` Tensors.
    """
    x = _seeded_input()
    net = ReluTanhSoftmax()
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    out_me_relu_01, out_me_tanh_01, out_me_softmax_01 = net(Tensor(x))
    return out_me_relu_01, out_me_tanh_01, out_me_softmax_01


def mixed_precision_multiple_cells_temp_02():
    """Same calls as ``mixed_precision_multiple_cells_temp_01`` but with the
    net-wide float16 ``to_float`` issued after the per-child overrides.

    How the later net-wide call interacts with the child settings is decided
    by MindSpore's ``Cell.to_float``; this case only checks that both
    execution modes agree on it.

    Returns:
        tuple: ``(relu_out, tanh_out, softmax_out)`` Tensors.
    """
    x = _seeded_input()
    net = ReluTanhSoftmax()
    net.relu.to_float(ms.float32)
    net.softmax.to_float(ms.float16)
    net.to_float(ms.float16)
    out_me_relu_02, out_me_tanh_02, out_me_softmax_02 = net(Tensor(x))
    return out_me_relu_02, out_me_tanh_02, out_me_softmax_02


def mixed_precision_multiple_cells_temp_03():
    """Run ReluTanhAdd with the net set to float16 and ``relu``/``add`` to float32.

    Returns:
        Tensor: The network output.
    """
    x = _seeded_input()
    net = ReluTanhAdd()
    net.to_float(ms.float16)
    net.relu.to_float(ms.float32)
    net.add.to_float(ms.float32)
    out_me = net(Tensor(x))
    return out_me


def _compare_graph_and_pynative(run_case):
    """Run ``run_case`` in GRAPH_MODE, then PYNATIVE_MODE, and compare outputs.

    The execution mode that was active before the call is restored afterwards
    so later tests in the same process are not left in PYNATIVE_MODE.

    Args:
        run_case (Callable[[], Tensor | tuple]): Zero-argument builder that
            returns a single Tensor or a tuple of Tensors.
    """
    device_target = context.get_context('device_target')
    previous_mode = context.get_context('mode')
    try:
        context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
        graph_outputs = run_case()
        context.set_context(mode=context.PYNATIVE_MODE, device_target=device_target)
        pynative_outputs = run_case()
    finally:
        context.set_context(mode=previous_mode)

    # Normalise single-Tensor results so both shapes share one comparison loop.
    if not isinstance(graph_outputs, tuple):
        graph_outputs = (graph_outputs,)
    if not isinstance(pynative_outputs, tuple):
        pynative_outputs = (pynative_outputs,)
    assert len(graph_outputs) == len(pynative_outputs)
    for graph_out, pynative_out in zip(graph_outputs, pynative_outputs):
        allclose_nparray(graph_out.asnumpy(), pynative_out.asnumpy(), 0.001, 0.001)


def mixed_precision_multiples_cell_01():
    """Check GRAPH/PYNATIVE agreement for ``mixed_precision_multiple_cells_temp_01``."""
    _compare_graph_and_pynative(mixed_precision_multiple_cells_temp_01)


def mixed_precision_multiples_cell_02():
    """Check GRAPH/PYNATIVE agreement for ``mixed_precision_multiple_cells_temp_02``."""
    _compare_graph_and_pynative(mixed_precision_multiple_cells_temp_02)


def mixed_precision_multiples_cell_03():
    """Check GRAPH/PYNATIVE agreement for ``mixed_precision_multiple_cells_temp_03``."""
    _compare_graph_and_pynative(mixed_precision_multiple_cells_temp_03)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_01():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_01()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_01():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_01()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_02():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_02()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_02():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_02()


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_ascend_03():
    context.set_context(device_target="Ascend")
    mixed_precision_multiples_cell_03()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mixed_precision_multiples_cell_gpu_03():
    context.set_context(device_target="GPU")
    mixed_precision_multiples_cell_03()