1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <list> 17 #include "common/common_test.h" 18 #include "frontend/parallel/device.h" 19 #include "frontend/parallel/device_manager.h" 20 #include "frontend/parallel/group_manager.h" 21 22 namespace mindspore { 23 namespace parallel { 24 25 class TestDevice : public UT::Common { 26 public: 27 TestDevice() {} 28 void SetUp(); 29 void TearDown(); 30 Device dev_1; 31 Device dev_2; 32 }; 33 34 void TestDevice::SetUp() { 35 std::string name = "#1"; 36 dev_1 = Device(name, std::int32_t(1)); 37 dev_2 = Device(std::int32_t(2)); 38 } 39 40 void TestDevice::TearDown() { 41 // destroy resources 42 } 43 44 TEST_F(TestDevice, test_device) { 45 std::string name = "#1"; 46 int32_t dev1_rank = 1; 47 int32_t dev2_rank = 2; 48 49 ASSERT_STREQ(dev_1.name().data(), name.data()); 50 ASSERT_EQ(dev_1.rank(), dev1_rank); 51 ASSERT_EQ(dev_2.rank(), dev2_rank); 52 } 53 54 // need to complete 55 class TestStage : public UT::Common {}; 56 57 class TestDeviceManager : public UT::Common { 58 public: 59 TestDeviceManager() {} 60 void SetUp(); 61 void TearDown(); 62 DeviceManager dm_; 63 }; 64 65 void TestDeviceManager::SetUp() { dm_ = DeviceManager::GetInstance(); } 66 67 void TestDeviceManager::TearDown() { 68 // destroy resources 69 } 70 71 TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) { 72 RankList dev_list; 73 RankList stage_map; 74 int32_t local_dev = 0; 75 76 dev_list.push_back(5); 77 dev_list.push_back(3); 78 dev_list.push_back(1); 79 dev_list.push_back(0); 80 81 stage_map.push_back(2); 82 stage_map.push_back(2); 83 ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS); 84 85 ASSERT_EQ(dm_.DeviceNum(), 4); 86 ASSERT_EQ(dm_.stage_num(), (int32_t)(2)); 87 88 RankList dev_list_0 = dm_.GetDeviceListByStageId(0); 89 RankList dev_list_1 = dm_.GetDeviceListByStageId(1); 90 ASSERT_EQ(dev_list_0.size(), 2); 91 ASSERT_EQ(dev_list_1.size(), 2); 92 93 RankList::iterator it = dev_list_0.begin(); 94 ASSERT_EQ((*it), int32_t(5)); 95 it++; 96 ASSERT_EQ((*it), int32_t(3)); 97 it = dev_list_1.begin(); 98 ASSERT_EQ((*it), int32_t(1)); 99 it++; 100 ASSERT_EQ((*it), int32_t(0)); 101 } 102 103 TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) { 104 Device one = dm_.CreateNewDeviceByRank(int32_t(3)); 105 ASSERT_EQ(one.rank(), int32_t(3)); 106 } 107 108 TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) { 109 std::vector<Device> dev_list; 110 RankList rlist; 111 rlist.push_back(int32_t(2)); 112 rlist.push_back(int32_t(1)); 113 dev_list = dm_.CreateDeviceListByRankList(rlist); 114 115 std::vector<Device>::iterator it = dev_list.begin(); 116 ASSERT_EQ(it->rank(), int32_t(2)); 117 it++; 118 ASSERT_EQ(it->rank(), int32_t(1)); 119 } 120 121 TEST_F(TestDeviceManager, test_StageID) { 122 RankList dev_list; 123 RankList stage_map; 124 int32_t local_dev = 2; 125 126 dev_list.push_back(0); 127 dev_list.push_back(1); 128 dev_list.push_back(2); 129 dev_list.push_back(3); 130 131 stage_map.push_back(2); 132 stage_map.push_back(2); 133 ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS); 134 135 ASSERT_EQ(dm_.DeviceNum(), 4); 136 ASSERT_EQ(dm_.stage_num(), 2); 137 ASSERT_EQ(dm_.stage_id(), 1); 138 ASSERT_EQ(dm_.rank_index_in_stage(), 0); 139 ASSERT_EQ(dm_.GetDeviceListInThisStage().back(), 3); 140 141 RankList dev_list_0 = dm_.GetDeviceListByStageId(0); 142 RankList dev_list_1 = dm_.GetDeviceListByStageId(1); 143 ASSERT_EQ(dev_list_0.size(), 2); 144 ASSERT_EQ(dev_list_1.size(), 2); 145 } 146 } // namespace parallel 147 } // namespace mindspore 148