• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
Function:
    Run one training step of a network (ResNet-50, LeNet or AlexNet) as a smoke test.
Usage:
    python test_network_main.py --net lenet --device Ascend
20"""
21import argparse
22
23import numpy as np
24from models.alexnet import AlexNet
25from models.lenet import LeNet
26from models.resnetv1_5 import resnet50
27
28import mindspore.context as context
29import mindspore.nn as nn
30from mindspore import Tensor
31from mindspore.nn import TrainOneStepCell, WithLossCell
32from mindspore.nn.optim import Momentum
33
# Select MindSpore graph mode on the Ascend device at import time, so the
# test_* functions below see this configuration even when they are called
# without going through the __main__ block (which may later override the
# device target from the command line).
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
)
35
36
def train(net, data, label, *, learning_rate=0.01, momentum=0.9):
    """Run a single Momentum / softmax-cross-entropy training step on ``net``.

    Args:
        net: the network cell to train.
        data: input batch Tensor fed to ``net``.
        label: integer class-index labels for ``data`` (the criterion is
            built with ``sparse=True``).
        learning_rate: Momentum learning rate (keyword-only, default 0.01).
        momentum: Momentum factor (keyword-only, default 0.9).

    Returns:
        The loss Tensor produced by the training step.

    Raises:
        AssertionError: if any element of the loss is NaN or infinite.
    """
    # Only parameters that require gradients are handed to the optimizer.
    params = filter(lambda p: p.requires_grad, net.get_parameters())
    optimizer = Momentum(params, learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    res = train_network(data, label)
    print(res)
    # `assert res` only tested truthiness: a NaN/inf loss (a diverged step)
    # passed while an exactly-zero loss failed. Check finiteness instead,
    # element-wise so any loss shape is handled.
    loss = res.asnumpy()
    assert np.all(np.isfinite(loss)), "non-finite loss: {}".format(loss)
    return res
49
50
def test_resnet50():
    """Train ResNet-50 for one step on a constant 32x3x224x224 float32 batch."""
    batch = 32
    inputs = Tensor(np.full((batch, 3, 224, 224), 0.01, dtype=np.float32))
    targets = Tensor(np.ones((batch,), dtype=np.int32))
    # NOTE(review): arguments are presumably (batch_size, num_classes);
    # confirm against models.resnetv1_5.resnet50.
    model = resnet50(32, 10)
    train(model, inputs, targets)
56
57
def test_lenet():
    """Train LeNet for one step on a constant 3x32x32 batch of ``net.batch_size``."""
    net = LeNet()
    n = net.batch_size
    images = Tensor(np.full((n, 3, 32, 32), 0.01, dtype=np.float32))
    labels = Tensor(np.ones((n,), dtype=np.int32))
    train(net, images, labels)
63
64
def test_alexnet():
    """Train AlexNet for one step on a constant 32x3x227x227 float32 batch."""
    batch_size = 32
    images = Tensor(np.full((batch_size, 3, 227, 227), 0.01, dtype=np.float32))
    labels = Tensor(np.ones((batch_size,), dtype=np.int32))
    train(AlexNet(), images, labels)
70
71
# Each accepted --net value mapped to the test function that trains it. The
# same mapping feeds argparse `choices`, keeping the CLI and dispatch in sync.
_NET_TESTS = {
    'resnet50': test_resnet50,
    'lenet': test_lenet,
    'alexnet': test_alexnet,
}

parser = argparse.ArgumentParser(description='MindSpore Testing Network')
# `choices` makes argparse reject an unknown network with a usage error and a
# non-zero exit status, so a typo fails loudly instead of exiting successfully.
parser.add_argument('--net', default='resnet50', type=str, choices=sorted(_NET_TESTS),
                    help='net name')
parser.add_argument('--device', default='Ascend', type=str, help='device target')

if __name__ == "__main__":
    args = parser.parse_args()
    # Override the module-level default device with the command-line value.
    context.set_context(device_target=args.device)
    _NET_TESTS[args.net]()
86