• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <vector>
17 #include "common/common_test.h"
18 #include "common/py_func_graph_fetcher.h"
19 #include "frontend/parallel/device_matrix.h"
20 
21 namespace mindspore {
22 namespace parallel {
23 
24 class TestDeviceMatrix : public UT::Common {
25  public:
TestDeviceMatrix()26   TestDeviceMatrix() {}
27 
SetUp()28   void SetUp() { UT::InitPythonPath(); }
29 
TearDown()30   virtual void TearDown() {}
31 };
32 
// 2x3 device matrix over ranks 0..5, queried from rank 0. Per the expectation below,
// the group along dim 0 is {0, 3} and the group along dim 1 is {0, 1, 2}.
TEST_F(TestDeviceMatrix, Test2Dgroup_list) {
  RankList dev_list = {0, 1, 2, 3, 4, 5};
  Shape shape = {2, 3};

  DeviceMatrix arr(0, dev_list, shape);
  // Assert the status directly so a failure is reported as such, rather than as a
  // mismatch against an empty group list.
  ASSERT_EQ(arr.CreateGroupList(), Status::SUCCESS);
  std::vector<RankList> group_list = arr.group_list();
  std::vector<RankList> group_list_expect = {{0, 3}, {0, 1, 2}};
  ASSERT_EQ(group_list, group_list_expect);
}
43 
// 2x2x3 device matrix over ranks 0..11, queried from rank 5. The expected list holds
// one group per device-matrix dimension, in dimension order.
TEST_F(TestDeviceMatrix, Test3Dgroup_list) {
  RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  Shape shape = {2, 2, 3};

  DeviceMatrix arr(5, dev_list, shape);
  // Fail explicitly on a bad status instead of comparing an empty list.
  ASSERT_EQ(arr.CreateGroupList(), Status::SUCCESS);
  std::vector<RankList> group_list = arr.group_list();
  std::vector<RankList> group_list_expect = {{5, 11}, {2, 5}, {3, 4, 5}};
  ASSERT_EQ(group_list, group_list_expect);
}
54 
// 2x1x4x2 device matrix over ranks 0..15, queried from rank 5. The size-1 dimension
// is expected to yield a singleton group {5}.
TEST_F(TestDeviceMatrix, Test4DGetAlongDim) {
  RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
  Shape shape = {2, 1, 4, 2};

  DeviceMatrix arr(5, dev_list, shape);
  // Fail explicitly on a bad status instead of comparing an empty list.
  ASSERT_EQ(arr.CreateGroupList(), Status::SUCCESS);
  std::vector<RankList> group_list = arr.group_list();
  std::vector<RankList> group_list_expect = {{5, 13}, {5}, {1, 3, 5, 7}, {4, 5}};
  ASSERT_EQ(group_list, group_list_expect);
}
65 
// 3x4x2x3x2 device matrix over ranks 0..143 (3*4*2*3*2 = 144), queried from rank 5.
TEST_F(TestDeviceMatrix, Test5DGetAlongDim) {
  RankList dev_list;
  for (int i = 0; i < 144; i++) {
    dev_list.push_back(i);
  }
  Shape shape = {3, 4, 2, 3, 2};

  DeviceMatrix arr(5, dev_list, shape);
  // Fail explicitly on a bad status instead of comparing an empty list.
  ASSERT_EQ(arr.CreateGroupList(), Status::SUCCESS);
  std::vector<RankList> group_list = arr.group_list();
  std::vector<RankList> group_list_expect = {{5, 53, 101}, {5, 17, 29, 41}, {5, 11}, {1, 3, 5}, {4, 5}};
  ASSERT_EQ(group_list, group_list_expect);
}
77 
// Constructing a DeviceMatrix whose shape covers 2*2*2 = 8 slots from a 9-entry
// device list must throw std::runtime_error.
TEST_F(TestDeviceMatrix, TestCornerCaseGetAlongDim) {
  RankList nine_devices = {0, 1, 2, 3, 4, 5, 6, 7, 8};
  Shape mismatched_shape = {2, 2, 2};

  EXPECT_THROW(
      {
        DeviceMatrix matrix(3, nine_devices, mismatched_shape);
      },
      std::runtime_error);
}
85 
// Device list in arbitrary order, 4x2 shape, queried from rank 0 (the last entry).
// Tensor map {-1, 0}: -1 presumably marks an unsplit tensor dimension — TODO confirm
// against DeviceMatrix::GetDevicesByTensorMap.
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceOne) {
  RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
  Shape tensor_map = {-1, 0};
  RankList rank_list;
  Shape shape = {4, 2};
  DeviceMatrix arr(0, dev_list, shape);
  // Check the returned status instead of discarding it.
  ASSERT_EQ(arr.GetDevicesByTensorMap(tensor_map, &rank_list), Status::SUCCESS);
  RankList rank_list_expect = {3, 9, 100, 0};
  ASSERT_EQ(rank_list, rank_list_expect);
}
96 
// Same arbitrary-order 4x2 device list as the SliceOne case, but with tensor map {1, 0}.
// Per the expectation, only rank 0 itself is returned.
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapRandomOrderSliceTwo) {
  RankList dev_list = {10, 3, 2, 9, 11, 100, 1, 0};
  Shape tensor_map = {1, 0};
  RankList rank_list;
  Shape shape = {4, 2};
  DeviceMatrix arr(0, dev_list, shape);
  // Check the returned status instead of discarding it.
  ASSERT_EQ(arr.GetDevicesByTensorMap(tensor_map, &rank_list), Status::SUCCESS);
  RankList rank_list_expect = {0};
  ASSERT_EQ(rank_list, rank_list_expect);
}
107 
// Ranks 0..7 in natural order, 4x2 shape, queried from rank 6 with tensor map {-1, 0}.
// (The "Noramal" typo in the test name is kept so existing gtest filters still match.)
TEST_F(TestDeviceMatrix, TestGetDeviceByTensorMapNoramalOrder2D) {
  RankList dev_list = {0, 1, 2, 3, 4, 5, 6, 7};
  Shape tensor_map = {-1, 0};
  RankList rank_list;
  Shape shape = {4, 2};
  DeviceMatrix arr(6, dev_list, shape);
  // Check the returned status instead of discarding it.
  ASSERT_EQ(arr.GetDevicesByTensorMap(tensor_map, &rank_list), Status::SUCCESS);
  RankList rank_list_expect = {0, 2, 4, 6};
  ASSERT_EQ(rank_list, rank_list_expect);
}
118 
// Constructing a DeviceMatrix for rank 8, which does not appear in the device
// list 0..7, must throw std::runtime_error.
TEST_F(TestDeviceMatrix, TestCornerCase2GetAlongDim) {
  RankList eight_devices = {0, 1, 2, 3, 4, 5, 6, 7};
  Shape cube_shape = {2, 2, 2};

  EXPECT_THROW(
      {
        DeviceMatrix matrix(8, eight_devices, cube_shape);
      },
      std::runtime_error);
}
126 
127 }  // namespace parallel
128 }  // namespace mindspore
129