1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <list> 17 #include "frontend/parallel/device_manager.h" 18 #include "common/common_test.h" 19 #include "frontend/parallel/device.h" 20 #include "frontend/parallel/group_manager.h" 21 22 namespace mindspore { 23 namespace parallel { 24 25 extern DeviceManagerPtr g_device_manager; 26 27 class TestGroup : public UT::Common { 28 public: 29 TestGroup() {} 30 void SetUp(); 31 void TearDown(); 32 Status Init(); 33 34 Group gp; 35 }; 36 37 void TestGroup::SetUp() { gp = Group(); } 38 39 void TestGroup::TearDown() { 40 // destroy resources 41 } 42 43 Status TestGroup::Init() { 44 std::string gname = "1-2"; 45 std::vector<Device> dev_list; 46 Device one = Device(int32_t(1)); 47 dev_list.push_back(one); 48 Device two = Device(int32_t(2)); 49 dev_list.push_back(two); 50 51 return gp.Init(gname, dev_list); 52 } 53 54 TEST_F(TestGroup, test_Init) { ASSERT_EQ(Init(), Status::SUCCESS); } 55 56 TEST_F(TestGroup, test_GetDevicesList) { 57 Init(); 58 std::vector<Device> res_dev_list = gp.GetDevicesList(); 59 std::vector<Device>::iterator it = res_dev_list.begin(); 60 ASSERT_EQ(it->rank(), int32_t(1)); 61 it++; 62 ASSERT_EQ(it->rank(), int32_t(2)); 63 } 64 65 TEST_F(TestGroup, test_IsInThisGroup) { 66 Init(); 67 ASSERT_TRUE(gp.IsInThisGroup(int32_t(1))); 68 ASSERT_TRUE(gp.IsInThisGroup(int32_t(2))); 69 70 ASSERT_FALSE(gp.IsInThisGroup(int32_t(3))); 71 } 72 73 class TestGroupManager : public UT::Common { 74 public: 75 TestGroupManager() {} 76 void SetUp(); 77 void TearDown(); 78 Status Init(Group** gp_ptr); 79 80 GroupManager gm; 81 }; 82 83 void TestGroupManager::SetUp() { gm = GroupManager(); } 84 85 void TestGroupManager::TearDown() { 86 // destroy resources 87 } 88 89 Status TestGroupManager::Init(Group** gp_ptr) { 90 std::string gname = "1-2"; 91 std::vector<Device> dev_list; 92 Device one = Device(int32_t(1)); 93 dev_list.push_back(one); 94 Device two = Device(int32_t(2)); 95 dev_list.push_back(two); 96 97 return gm.CreateGroup(gname, dev_list, *gp_ptr); 98 } 99 100 TEST_F(TestGroupManager, test_CreateGroup) { 101 // testing for creating a group 102 Group* gp_ptr = new Group(); 103 ASSERT_EQ(Init(&gp_ptr), Status::SUCCESS); 104 105 std::vector<Device> res_dev_list = gp_ptr->GetDevicesList(); 106 std::vector<Device>::iterator it = res_dev_list.begin(); 107 ASSERT_EQ(it->rank(), int32_t(1)); 108 it++; 109 ASSERT_EQ(it->rank(), int32_t(2)); 110 delete gp_ptr; 111 112 // testing for creating a group with an existing group name 113 std::vector<Device> dev_list2; 114 Device three = Device(int32_t(3)); 115 dev_list2.push_back(three); 116 Device four = Device(int32_t(4)); 117 dev_list2.push_back(four); 118 gp_ptr = new Group(); 119 ASSERT_EQ(gm.CreateGroup("1-2", dev_list2, gp_ptr), Status::SUCCESS); 120 121 ASSERT_STREQ(gp_ptr->name().data(), "1-2"); 122 std::vector<Device> res_dev_list2 = gp_ptr->GetDevicesList(); 123 std::vector<Device>::iterator it2 = res_dev_list2.begin(); 124 ASSERT_EQ(it2->rank(), int32_t(1)); 125 it2++; 126 ASSERT_EQ(it2->rank(), int32_t(2)); 127 delete gp_ptr; 128 gp_ptr = nullptr; 129 } 130 131 TEST_F(TestGroupManager, test_FindGroup) { 132 std::string gname = "1-2"; 133 Group* gp_ptr = new Group(); 134 Group* gp_ptr2 = new Group(); 135 ASSERT_EQ(Init(&gp_ptr), Status::SUCCESS); 136 137 ASSERT_EQ(gm.FindGroup(gname, &gp_ptr2), Status::SUCCESS); 138 139 std::vector<Device> res_dev_list = gp_ptr2->GetDevicesList(); 140 std::vector<Device>::iterator it = res_dev_list.begin(); 141 ASSERT_EQ(it->rank(), int32_t(1)); 142 it++; 143 ASSERT_EQ(it->rank(), int32_t(2)); 144 delete gp_ptr; 145 gp_ptr = nullptr; 146 147 std::string gname2 = "3-4"; 148 gp_ptr2 = new Group(); 149 ASSERT_EQ(gm.FindGroup(gname2, &gp_ptr2), Status::FAILED); 150 delete gp_ptr2; 151 gp_ptr2 = nullptr; 152 } 153 154 } // namespace parallel 155 } // namespace mindspore 156