import numpy as np

from mindspore import Tensor
from mindspore.rewrite import SymbolTree
from tests.models.official.cv.mobilenetv2.src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2


def define_net():
    """Build a MobileNetV2 backbone, a 2-class head, and the combined network.

    Returns:
        tuple: ``(backbone_net, head_net, net)`` where ``net`` is the result of
        ``mobilenet_v2(backbone_net, head_net)``.
    """
    backbone_net = MobileNetV2Backbone()
    # The *string* "None" (not the None object) is forwarded to the head.
    # NOTE(review): presumably MobileNetV2Head interprets it as "no final
    # activation" -- confirm against the head implementation.
    activation = "None"
    head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
                               num_classes=2,
                               activation=activation)
    net = mobilenet_v2(backbone_net, head_net)
    return backbone_net, head_net, net


def test_mobilenet():
    """
    Feature: Test Rewrite.
    Description: Test Rewrite on Mobilenetv2.
    Expectation: Success.
    """
    _, _, net = define_net()
    predict = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
    # Reference output from the original, un-rewritten network.
    expect = net(predict)
    stree = SymbolTree.create(net)
    net_opt = stree.get_network()
    output = net_opt(predict)
    # np.testing.assert_allclose, unlike np.allclose, does not broadcast its
    # operands: a rewritten network whose output shape differs from the
    # original now fails instead of passing silently, and a failure reports
    # the mismatching elements. rtol/atol/equal_nan are pinned to
    # np.allclose's defaults so the numeric tolerance is unchanged.
    np.testing.assert_allclose(output.asnumpy(), expect.asnumpy(),
                               rtol=1e-5, atol=1e-8, equal_nan=False)