# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" infer: run one forward pass of a registered network on mock data. """
from argparse import ArgumentParser
import numpy as np

from mindspore import Tensor
from ....dataset_mock import MindData

# Registry mapping a model name (the ``--model`` CLI value) to a network
# instance. The instance is built once, when this module is imported.
# NOTE(review): `resnet50` is not imported or defined in the part of the file
# visible here -- confirm where it should come from and import it, otherwise
# importing this module raises NameError.
__factory = {
    "resnet50": resnet50(),
}


def parse_args():
    """Parse the command-line options of the example.

    Returns:
        argparse.Namespace with the attributes ``model``, ``phase``,
        ``file_path`` and ``batch_size``.
    """
    parser = ArgumentParser(description="resnet50 example")

    parser.add_argument("--model", type=str, default="resnet50",
                        help="the network architecture for training or testing")
    parser.add_argument("--phase", type=str, default="test",
                        help="the phase of the model, default is test.")
    parser.add_argument("--file_path", type=str, default="/data/file/test1.txt",
                        help="data directory of training or testing")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="batch size for training or testing ")

    return parser.parse_args()


def get_model(name):
    """Look up a network in the module-level registry.

    Args:
        name: key of the network in ``__factory``.

    Returns:
        The network object registered under ``name``.

    Raises:
        KeyError: if ``name`` is not a registered model.
    """
    if name not in __factory:
        # Build one message string; passing two positional args to KeyError
        # would render the error as a tuple instead of readable text.
        raise KeyError("unknown model: {}".format(name))
    return __factory[name]


def get_dataset(batch_size=32):
    """Create the mock dataset fed to the network.

    Args:
        batch_size: forwarded to ``MindData`` and used as the leading
            dimension of the declared output shape.

    Returns:
        A ``MindData`` instance configured with ``size=2``, float32 data and
        an output shape of ``(batch_size, 3, 224, 224)``.
    """
    dataset_types = np.float32
    dataset_shapes = (batch_size, 3, 224, 224)

    dataset = MindData(size=2, batch_size=batch_size,
                       np_types=dataset_types,
                       output_shapes=dataset_shapes,
                       input_indexs=(0, 1))
    return dataset


# pylint: disable=unused-argument
def test(name, file_path, batch_size):
    """Run the network registered under ``name`` once on mock data.

    Args:
        name: model key passed to ``get_model``.
        file_path: accepted for CLI compatibility; not used by this function.
        batch_size: batch size passed to ``get_dataset``.

    Raises:
        KeyError: propagated from ``get_model`` for an unknown ``name``.
    """
    network = get_model(name)

    batch = get_dataset(batch_size=batch_size)

    # Convert every item yielded by the dataset to a NumPy array.
    data_list = [data.asnumpy() for data in batch]
    # Stack along axis 0, then reorder the axes as (0, 3, 1, 2).
    # NOTE(review): get_dataset declares shapes as (batch, 3, 224, 224);
    # confirm this transpose yields the layout the network expects.
    batch_data = np.concatenate(data_list, axis=0).transpose((0, 3, 1, 2))
    input_tensor = Tensor(batch_data)
    print(input_tensor.shape)
    network(input_tensor)


if __name__ == '__main__':
    args = parse_args()
    # Only the inference ("test") phase is implemented by this script.
    if args.phase == "train":
        raise NotImplementedError
    test(args.model, args.file_path, args.batch_size)