• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import numpy as np
2
3from mindspore import Tensor
4from mindspore.rewrite import SymbolTree
5from tests.models.official.cv.mobilenetv2.src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
6
7
def define_net():
    """Assemble a two-class MobileNetV2 from a fresh backbone and head.

    Returns:
        tuple: ``(backbone, head, network)``, where ``network`` is the value
        returned by ``mobilenet_v2(backbone, head)``.
    """
    backbone = MobileNetV2Backbone()
    # The string "None" (not Python's None) is passed for ``activation``;
    # presumably MobileNetV2Head treats it as "no final activation" --
    # TODO confirm against mobilenetV2.py.
    head = MobileNetV2Head(
        input_channel=backbone.out_channels,
        num_classes=2,
        activation="None",
    )
    return backbone, head, mobilenet_v2(backbone, head)
16
17
def test_mobilenet():
    """
    Feature: Test Rewrite.
    Description: Test Rewrite on Mobilenetv2.
    Expectation: Success.
    """
    _, _, net = define_net()
    predict = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
    # Reference output from the original, un-rewritten network.
    expect = net(predict).asnumpy()

    # Round-trip the network through a rewrite SymbolTree and rebuild it.
    stree = SymbolTree.create(net)
    net_opt = stree.get_network()
    output = net_opt(predict).asnumpy()

    # np.allclose broadcasts its arguments, so an output with a different but
    # broadcast-compatible shape would pass silently; pin the shape first.
    assert output.shape == expect.shape, (output.shape, expect.shape)
    # Same tolerances and NaN handling as np.allclose's defaults, but a
    # failure reports which elements differ and by how much.
    np.testing.assert_allclose(output, expect, rtol=1e-5, atol=1e-8, equal_nan=False)
31