# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn
from mindspore.rewrite import SymbolTree, Node, ScopedValue
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
import numpy as np
import pytest


class LeNet5(nn.Cell):
    """
    LeNet-5 style network used as the rewrite target in this test.

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.
        include_top (bool): Whether to build and run the flatten + dense head.
            When False, ``construct`` returns the output of the second
            pooling layer. Default: True.
    """
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120)
            self.fc2 = nn.Dense(120, 84)
            self.fc3 = nn.Dense(84, num_class)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
def test_rewrite_apis():
    """
    Feature: Test rewrite apis.
    Description: Test rewrite SymbolTree and Node apis.
    Expectation: success.
    """
    net = LeNet5()
    stree = SymbolTree.create(net)
    assert isinstance(stree, SymbolTree)
    assert len(list(stree.nodes())) == 16
    conv1_node = stree.get_node('conv1')
    assert isinstance(conv1_node, Node)
    node_name = conv1_node.get_name()
    assert node_name == 'conv1'

    # Insert a ReLU right after conv1 and point conv1's existing users at the
    # new node's output target.
    position = stree.after(conv1_node)
    new_node = Node.create_call_cell(cell=nn.ReLU(), targets=['x_1'],
                                     args=[ScopedValue.create_naming_value('x')], name='new_relu')
    for user in conv1_node.get_users():
        user.set_arg(0, new_node.get_targets()[0])
    stree.insert(position, new_node)
    assert conv1_node.get_users()[0] == new_node
    assert new_node.get_inputs()[0] == conv1_node
    assert len(list(stree.nodes())) == 17

    # Insert a ReLU right before conv2 and wire its output into conv2's first argument.
    conv2_node = stree.get_node('conv2')
    position = stree.before(conv2_node)
    new_node2 = Node.create_call_cell(cell=nn.ReLU(), targets=['x_2'],
                                      args=[ScopedValue.create_naming_value('x')], name='new_relu2')
    conv2_node.set_arg_by_node(0, new_node2, 0)
    stree.insert(position, new_node2)
    assert new_node2.get_users()[0] == conv2_node
    assert conv2_node.get_inputs()[0] == new_node2

    # Erase the node named "relu" and check it disappears from the tree.
    relu_node = stree.get_node("relu")
    assert len(list(stree.nodes())) == 18
    assert "relu" in [node.get_name() for node in stree.nodes()]
    stree.erase(relu_node)
    assert len(list(stree.nodes())) == 17
    assert "relu" not in [node.get_name() for node in stree.nodes()]

    # Replace the Flatten node with a new one. unique_name('x') is expected to
    # yield 'x_3' because 'x_1' and 'x_2' are already used as targets above.
    new_node3 = Node.create_call_cell(cell=nn.Flatten(), targets=[stree.unique_name('x')],
                                      args=[ScopedValue.create_naming_value('x')], name='new_flatten')
    assert new_node3.get_targets()[0] == ScopedValue.create_naming_value('x_3')
    flatten_node = None
    for node in stree.nodes():
        if node.get_instance_type() == nn.Flatten:
            flatten_node = node
            break
    assert flatten_node is not None
    for user in flatten_node.get_users():
        user.set_arg_by_node(0, new_node3, 0)
    assert "flatten" in [node.get_name() for node in stree.nodes()]
    stree.replace(flatten_node, [new_node3])
    assert "flatten" not in [node.get_name() for node in stree.nodes()]
    assert "new_flatten" in [node.get_name() for node in stree.nodes()]

    # str.find returns -1 (truthy) on a miss and 0 (falsy) on a match at index 0,
    # so asserting on its result checks nothing; use membership instead. The
    # third inserted cell is 'new_flatten' (no 'new_relu3' node is ever created).
    codes = stree.get_code()
    assert "self.new_relu" in codes
    assert "self.new_relu2" in codes
    assert "self.new_flatten" in codes

    # The rewritten network must still compile.
    net = stree.get_network()
    data_in = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
    _cell_graph_executor.compile(net, data_in)