# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export/load round-trip tests for MindIR graphs containing ``while`` loops.

Each test builds a small network with a data-dependent ``while`` loop and
exports it to a ``.mindir`` file. Most tests then load the file back through
``load`` + ``nn.GraphCell`` and check the loaded graph reproduces the original
network's output.
"""
import contextlib
import os

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import context, ms_function
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load


def _remove_if_exists(path):
    """Delete ``path``; silently do nothing if it does not exist."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


@contextlib.contextmanager
def _exported_mindir(network, *inputs, file_name):
    """Export ``network`` with ``inputs`` to ``<file_name>.mindir`` and yield the path.

    Any existing file of the same name is removed *before* exporting, so the
    existence check below proves that this export produced the file (rather
    than a leftover from an earlier test or run). The file is removed again on
    exit, including when export or the caller's body raises, so tests do not
    leave artifacts in the working directory.
    """
    mindir_name = file_name + ".mindir"
    _remove_if_exists(mindir_name)
    try:
        export(network, *inputs, file_name=file_name, file_format='MINDIR')
        assert os.path.exists(mindir_name)
        yield mindir_name
    finally:
        _remove_if_exists(mindir_name)


def _assert_same_output(expected, actual):
    """Assert two Tensors hold element-wise matching values.

    Compares through NumPy rather than ``assert expected == actual``: the
    latter relies on the truthiness of a boolean Tensor, which is ambiguous
    for outputs with more than one element.
    """
    assert np.allclose(expected.asnumpy(), actual.asnumpy())


class SingleWhileNet(nn.Cell):
    """Network with one ``while`` loop whose body updates both ``x`` and ``y``."""

    def construct(self, x, y):
        x += 1
        while x < y:
            x += 1
            y += 2 * x
        return y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while():
    """Export/load round trip of ``SingleWhileNet`` in graph mode."""
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)

    with _exported_mindir(network, x, y, file_name="while_net") as mindir_name:
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(x, y)
    _assert_same_output(origin_out, outputs_after_load)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_ms_function_while():
    """Run a loaded ``SingleWhileNet`` graph from an ``ms_function`` in PyNative mode.

    Note: this switches the global context to PYNATIVE_MODE; the other tests
    in this module set their own mode explicitly.
    """
    context.set_context(mode=context.PYNATIVE_MODE)
    network = SingleWhileNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)

    with _exported_mindir(network, x, y, file_name="ms_function_while_net") as mindir_name:
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)

        @ms_function
        def run_graph(x, y):
            outputs = loaded_net(x, y)
            return outputs

        outputs_after_load = run_graph(x, y)
    _assert_same_output(origin_out, outputs_after_load)


class SingleWhileInlineNet(nn.Cell):
    """Like ``SingleWhileNet`` but the loop body adds ``x`` (not ``2 * x``) to ``y``."""

    def construct(self, x, y):
        x += 1
        while x < y:
            x += 1
            y += x
        return y


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_export():
    """Exporting ``SingleWhileInlineNet`` to MINDIR produces a file."""
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))

    with _exported_mindir(network, x, y, file_name="while_inline_export_net"):
        # Export and the file-existence check happen inside the context manager.
        pass


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline_load():
    """An exported ``SingleWhileInlineNet`` MINDIR file can be loaded back."""
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))

    with _exported_mindir(network, x, y, file_name="while_inline_load_net") as mindir_name:
        graph = load(mindir_name)
    assert graph is not None


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_while_inline():
    """Export/load round trip of ``SingleWhileInlineNet`` in graph mode."""
    context.set_context(mode=context.GRAPH_MODE)
    network = SingleWhileInlineNet()

    x = Tensor(np.array([1]).astype(np.float32))
    y = Tensor(np.array([2]).astype(np.float32))
    origin_out = network(x, y)

    with _exported_mindir(network, x, y, file_name="while_inline_net") as mindir_name:
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(x, y)
    _assert_same_output(origin_out, outputs_after_load)