# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Server-side evaluation entry point for the ALBERT classification model.

Loads a trained checkpoint, tokenizes the evaluation set, runs inference
batch by batch, and reports accuracy via the Accuracy callback.
"""

import argparse
import os
import sys
from time import time

from mindspore import context
from mindspore.train.serialization import load_checkpoint

from src.assessment_method import Accuracy
from src.config import eval_cfg, server_net_cfg
from src.dataset import load_datasets
from src.model import AlbertModelCLS
from src.tokenization import CustomizedTextTokenizer
from src.utils import restore_params


def parse_args():
    """Parse command-line arguments for the server evaluation task.

    Returns:
        argparse.Namespace: parsed arguments (device_target, device_id,
        tokenizer_dir, eval_data_dir, model_path, vocab_map_ids_path).
    """
    parser = argparse.ArgumentParser(description='server eval task')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU', 'CPU'])
    parser.add_argument('--device_id', type=str, default='0')
    parser.add_argument('--tokenizer_dir', type=str, default='../model_save/init/')
    parser.add_argument('--eval_data_dir', type=str, default='../datasets/eval/')
    parser.add_argument('--model_path', type=str, default='../model_save/train_server/0.ckpt')
    parser.add_argument('--vocab_map_ids_path', type=str, default='../model_save/init/vocab_map_ids.txt')

    return parser.parse_args()


def server_eval(args):
    """Evaluate a trained ALBERT classifier checkpoint on the eval dataset.

    Args:
        args (argparse.Namespace): arguments produced by parse_args().

    Side effects:
        Sets CUDA_VISIBLE_DEVICES, configures the MindSpore context, and
        prints timing / accuracy progress to stdout.
    """
    start = time()

    # Device selection. CUDA_VISIBLE_DEVICES only has an effect when the
    # target is GPU; it is harmless otherwise.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id
    tokenizer_dir = args.tokenizer_dir
    eval_data_dir = args.eval_data_dir
    model_path = args.model_path
    vocab_map_ids_path = args.vocab_map_ids_path

    # mindspore context
    # BUG FIX: device_target was hardcoded to 'GPU', silently ignoring the
    # --device_target argument (choices: Ascend/GPU/CPU). Honor it here.
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    print('Context setting is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # data process: build the tokenizer and the evaluation dataset pipeline.
    tokenizer = CustomizedTextTokenizer.from_pretrained(tokenizer_dir, vocab_map_ids_path=vocab_map_ids_path)
    datasets_list, _ = load_datasets(
        eval_data_dir, server_net_cfg.seq_length, tokenizer, eval_cfg.batch_size,
        label_list=None,
        do_shuffle=False,       # deterministic order for evaluation
        drop_remainder=False,   # keep the final partial batch
        output_dir=None)
    print('Data process is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # main model: build in inference mode and restore checkpoint weights.
    albert_model_cls = AlbertModelCLS(server_net_cfg)
    albert_model_cls.set_train(False)
    param_dict = load_checkpoint(model_path)
    restore_params(albert_model_cls, param_dict)
    print('Model construction is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()
    start = time()

    # eval: accumulate accuracy over every batch of every dataset shard.
    callback = Accuracy()
    global_step = 0
    for datasets in datasets_list:
        for batch in datasets.create_tuple_iterator():
            input_ids, attention_mask, token_type_ids, label_ids, _ = batch
            logits = albert_model_cls(input_ids, attention_mask, token_type_ids)
            callback.update(logits, label_ids)
            print('eval step: {}, {}: {}'.format(global_step, callback.name, callback.get_metrics()))
            sys.stdout.flush()
            global_step += 1
    metrics = callback.get_metrics()
    print('Final {}: {}'.format(callback.name, metrics))
    sys.stdout.flush()
    print('Evaluating process is done! Time cost: {}'.format(time() - start))
    sys.stdout.flush()


if __name__ == '__main__':
    args_opt = parse_args()
    server_eval(args_opt)