• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import os
16import pytest
17
18from tests.st.model_zoo_tests import utils
19
20
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_center_net():
    """Run a distributed CenterNet training job and check speed and loss.

    Steps performed:
      1. Copy the ``centernet`` model directory next to this test file.
      2. Patch ``train.py`` and the distributed-launch command generator
         with text substitutions (``utils.exec_sed_command``).
      3. Start ``scripts/run_distributed_train_ascend.sh`` from inside the
         copied model directory.
      4. Wait on the training processes via ``utils.process_check``.
      5. For each of the 8 per-device logs, assert the per-step time is
         below 435, then assert the mean of the last reported losses is
         below 58.8.
    """
    # Number of per-device log directories (LOG0 .. LOG7) inspected below.
    device_num = 8

    cur_path = os.path.dirname(os.path.abspath(__file__))
    model_path = "{}/../../../../tests/models/research/cv".format(cur_path)
    model_name = "centernet"
    utils.copy_files(model_path, cur_path, model_name)
    cur_model_path = os.path.join(cur_path, model_name)

    # Replace the repeat count / sink steps expressions in train.py with
    # fixed values (presumably to keep the run short -- confirm against
    # the model's train.py).
    old_list = ['new_repeat_count, dataset', 'args_opt.data_sink_steps']
    new_list = ['5, dataset', '20']
    utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "train.py"))

    # Patch the launch-command generator: the replacements insert ''' markers
    # around parts of the script, hard-code rank_size = 8 and switch the
    # rank-table lookups to "group_list" / "instance_list" / "devices" keys.
    # NOTE(review): looks like an adaptation to a different rank-table
    # layout -- verify against the rank table used by the CI environment.
    old_list = ["device_ips = {}", "device_ip.strip()",
                "rank_size = 0", "this_server = server",
                "this_server\\[\\\"device\\\"\\]",
                "instance\\[\\\"device_id\\\"\\]"]
    new_list = ["device_ips = {}\\n    '''", "device_ip.strip()\\n    '''",
                "rank_size = 8\\n    this_server = hccl_config[\\\"group_list\\\"][0]\\n    '''",
                "this_server = server\\n    '''",
                "this_server[\\\"instance_list\\\"]",
                "instance[\\\"devices\\\"][0][\\\"device_id\\\"]"]
    generator_cmd_file = "scripts/ascend_distributed_launcher/get_distribute_train_cmd.py"
    utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, generator_cmd_file))

    dataset_path = os.path.join(utils.data_root, "coco/coco2017/mindrecord_train/centernet_mindrecord")
    # Enter the copied model directory by its absolute path so the launch does
    # not depend on the directory pytest was started from; "&&" prevents the
    # script from running in the wrong place if the cd fails.
    exec_network_shell = "cd {0} && bash scripts/run_distributed_train_ascend.sh {1} {2}"\
        .format(cur_model_path, dataset_path, utils.rank_table_path)
    os.system(exec_network_shell)

    # Track the launched train.py processes; a falsy result fails the test.
    cmd = "ps -ef |grep train.py | grep coco | grep -v grep"
    ret = utils.process_check(120, cmd)
    assert ret, "process check failed for the centernet training processes"

    # "{}" is filled with the device index to select each LOG<i> directory.
    log_file = os.path.join(cur_model_path, "LOG{}/training_log.txt")
    for i in range(device_num):
        per_step_time = utils.get_perf_data(log_file.format(i))
        assert per_step_time < 435, \
            "per-step time {} of device {} is not below 435".format(per_step_time, i)

    loss_list = []
    for i in range(device_num):
        # Extract the loss value from the "outputs are" lines of the log:
        # 14th whitespace-separated field, truncated at the first ")".
        loss_cmd = "grep -nr \"outputs are\" {} | awk '{{print $14}}' | awk -F\")\" '{{print $1}}'"\
            .format(log_file.format(i))
        loss = utils.get_loss_data_list(log_file.format(i), cmd=loss_cmd)
        # Only the last reported loss of each device enters the average.
        loss_list.append(loss[-1])
    mean_loss = sum(loss_list) / len(loss_list)
    assert mean_loss < 58.8, "mean final loss {} is not below 58.8".format(mean_loss)
63