• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include <list>
17 #include "common/common_test.h"
18 #include "frontend/parallel/device.h"
19 #include "frontend/parallel/device_manager.h"
20 #include "frontend/parallel/group_manager.h"
21 
22 namespace mindspore {
23 namespace parallel {
24 
25 class TestDevice : public UT::Common {
26  public:
TestDevice()27   TestDevice() {}
28   void SetUp();
29   void TearDown();
30   Device dev_1;
31   Device dev_2;
32 };
33 
SetUp()34 void TestDevice::SetUp() {
35   std::string name = "#1";
36   dev_1 = Device(name, std::int32_t(1));
37   dev_2 = Device(std::int32_t(2));
38 }
39 
TearDown()40 void TestDevice::TearDown() {
41   // destroy resources
42 }
43 
// Verifies that the devices built in SetUp() report the name and ranks
// they were constructed with.
TEST_F(TestDevice, test_device) {
  const std::string expected_name = "#1";
  const int32_t expected_rank_1 = 1;
  const int32_t expected_rank_2 = 2;

  ASSERT_STREQ(dev_1.name().data(), expected_name.data());
  ASSERT_EQ(dev_1.rank(), expected_rank_1);
  ASSERT_EQ(dev_2.rank(), expected_rank_2);
}
53 
54 // need to complete
55 class TestStage : public UT::Common {};
56 
57 class TestDeviceManager : public UT::Common {
58  public:
TestDeviceManager()59   TestDeviceManager() {}
60   void SetUp();
61   void TearDown();
62   DeviceManager dm_;
63 };
64 
SetUp()65 void TestDeviceManager::SetUp() { dm_ = DeviceManager::GetInstance(); }
66 
TearDown()67 void TestDeviceManager::TearDown() {
68   // destroy resources
69 }
70 
// Initializes four devices split into two stages of two devices each
// (stage_map {2, 2}) and checks that each stage receives its ranks in the
// order they appear in dev_list.
TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
  RankList dev_list = {5, 3, 1, 0};
  RankList stage_map = {2, 2};
  int32_t local_dev = 0;
  ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);

  ASSERT_EQ(dm_.DeviceNum(), 4);
  ASSERT_EQ(dm_.stage_num(), 2);

  // Whole-vector comparisons check size, content and order in one assertion.
  // They avoid comparing size() (unsigned) against a signed literal and never
  // dereference an iterator past the end of a short list.
  RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
  RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
  ASSERT_EQ(dev_list_0, (RankList{5, 3}));
  ASSERT_EQ(dev_list_1, (RankList{1, 0}));
}
102 
// A device created for a given rank must report that same rank.
TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) {
  const int32_t requested_rank = 3;
  Device created = dm_.CreateNewDeviceByRank(requested_rank);
  ASSERT_EQ(created.rank(), requested_rank);
}
107 
// Creating devices from a rank list must yield one device per rank, in the
// same order as the input list.
TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) {
  RankList rlist = {2, 1};
  std::vector<Device> dev_list = dm_.CreateDeviceListByRankList(rlist);

  // Check the size first: indexing a shorter result would be undefined
  // behaviour instead of a clean test failure.
  ASSERT_EQ(dev_list.size(), static_cast<size_t>(2));
  ASSERT_EQ(dev_list[0].rank(), int32_t(2));
  ASSERT_EQ(dev_list[1].rank(), int32_t(1));
}
120 
// With devices {0, 1, 2, 3}, stage_map {2, 2} and local device 2, the local
// device must land in stage 1 at index 0, and that stage's last device is 3.
TEST_F(TestDeviceManager, test_StageID) {
  RankList dev_list = {0, 1, 2, 3};
  RankList stage_map = {2, 2};
  int32_t local_dev = 2;
  ASSERT_EQ(dm_.Init(dev_list, local_dev, stage_map, "hccl"), Status::SUCCESS);

  ASSERT_EQ(dm_.DeviceNum(), 4);
  ASSERT_EQ(dm_.stage_num(), 2);
  ASSERT_EQ(dm_.stage_id(), 1);
  ASSERT_EQ(dm_.rank_index_in_stage(), 0);

  // Guard before back(): calling it on an empty list is undefined behaviour.
  auto this_stage = dm_.GetDeviceListInThisStage();
  ASSERT_FALSE(this_stage.empty());
  ASSERT_EQ(this_stage.back(), 3);

  RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
  RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
  ASSERT_EQ(dev_list_0.size(), static_cast<size_t>(2));
  ASSERT_EQ(dev_list_1.size(), static_cast<size_t>(2));
}
146 }  // namespace parallel
147 }  // namespace mindspore
148