1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <memory>
17 #include "common/common_test.h"
18 #include "runtime/hccl_adapter/all_to_all_v_calc_param.h"
19 #include "backend/session/anf_runtime_algorithm.h"
20 #include "mindspore/core/ir/dtype/type_id.h"
21
22 namespace mindspore::hccl {
23 class TestHcclAdapter : public UT::Common {
24 public:
TestHcclAdapter()25 TestHcclAdapter() {}
26
27 protected:
CreateAllToAllvNode(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> inputs,const std::vector<int64_t> & send_rank_ids,const std::vector<int64_t> & recv_rank_ids)28 CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> inputs,
29 const std::vector<int64_t> &send_rank_ids, const std::vector<int64_t> &recv_rank_ids) {
30 MS_EXCEPTION_IF_NULL(graph);
31 std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
32 all_to_all_v_input.insert(all_to_all_v_input.end(), inputs.begin(), inputs.end());
33 auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
34 MS_EXCEPTION_IF_NULL(all_to_all_v);
35 AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(send_rank_ids), all_to_all_v);
36 AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(recv_rank_ids), all_to_all_v);
37 AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>("default_group"), all_to_all_v);
38 return all_to_all_v;
39 }
40
SetOutputs(const CNodePtr & cnode,const std::vector<std::vector<size_t>> & shape,const std::vector<TypeId> & data_type)41 void SetOutputs(const CNodePtr &cnode, const std::vector<std::vector<size_t>> &shape,
42 const std::vector<TypeId> &data_type) {
43 AnfAlgo::SetOutputInferTypeAndShape(data_type, shape, cnode.get());
44 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
45 builder.SetFusionType(kernel::FusionType::OPAQUE);
46 builder.SetProcessor(kernel::Processor::AICORE);
47 builder.SetKernelType(TBE_KERNEL);
48 builder.SetInputsFormat(std::vector<std::string>(cnode->size() - 1, format_));
49 builder.SetOutputsFormat(std::vector<std::string>(shape.size(), format_));
50 builder.SetInputsDeviceType(std::vector<TypeId>(cnode->size() - 1, type_));
51 builder.SetOutputsDeviceType(std::vector<TypeId>(shape.size(), type_));
52 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
53 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
54 }
55
CreateInputs(const FuncGraphPtr & graph,const std::vector<std::vector<size_t>> & shape,const std::vector<TypeId> & data_type)56 std::vector<AnfNodePtr> CreateInputs(const FuncGraphPtr &graph, const std::vector<std::vector<size_t>> &shape,
57 const std::vector<TypeId> &data_type) {
58 MS_EXCEPTION_IF_NULL(graph);
59 if (shape.size() != data_type.size()) {
60 return {};
61 }
62 std::vector<AnfNodePtr> res;
63 for (size_t i = 0; i < shape.size(); ++i) {
64 auto node = graph->NewCNode(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>("AnyNameOp"))});
65 AnfAlgo::SetOutputInferTypeAndShape(std::vector<TypeId>{data_type[i]}, std::vector<std::vector<size_t>>{shape[i]},
66 node.get());
67 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
68 builder.SetFusionType(kernel::FusionType::OPAQUE);
69 builder.SetProcessor(kernel::Processor::AICORE);
70 builder.SetKernelType(TBE_KERNEL);
71 builder.SetInputsFormat({format_});
72 builder.SetOutputsFormat({format_});
73 builder.SetInputsDeviceType({type_});
74 builder.SetOutputsDeviceType({type_});
75 node->set_kernel_info(std::make_shared<device::KernelInfo>());
76 AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
77 res.emplace_back(node);
78 }
79 return res;
80 }
81
82 TypeId type_ = TypeId::kNumberTypeInt32;
83 std::string format_ = "NCHW";
84 };
85
86 /// Feature: AllToAllvCalcParam
87 /// Description: on 2p, send to rank 1, and recv nothing
88 /// Expectation: send count 0 1
89 /// send offset 0 0
90 /// recv count 0 0
91 /// recv offset 0 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_only_send)92 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_send) {
93 auto graph = std::make_shared<FuncGraph>();
94 ASSERT_TRUE(graph != nullptr);
95 uint32_t rank_size = 2;
96 std::vector<int64_t> send_rank_ids = {1};
97 std::vector<int64_t> recv_rank_ids = {};
98 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
99 ASSERT_TRUE(alltoall != nullptr);
100 ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
101 AllToAllvCalcParam calc(alltoall, rank_size);
102 ASSERT_NO_THROW(calc.CalcOpParam());
103 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
104 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
105 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0}));
106 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
107 }
108
109 /// Feature: AllToAllvCalcParam
110 /// Description: on 2p, send nothing, and recv from rank 0 and rank 1
111 /// Expectation: send count 0 0
112 /// send offset 0 0
113 /// recv count 1 1
114 /// recv offset 0 128
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_only_recv)115 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_recv) {
116 auto graph = std::make_shared<FuncGraph>();
117 ASSERT_TRUE(graph != nullptr);
118 uint32_t rank_size = 2;
119 std::vector<int64_t> send_rank_ids = {};
120 std::vector<int64_t> recv_rank_ids = {0, 1};
121 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
122 ASSERT_TRUE(alltoall != nullptr);
123 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
124 AllToAllvCalcParam calc(alltoall, rank_size);
125 ASSERT_NO_THROW(calc.CalcOpParam());
126 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 0}));
127 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
128 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1}));
129 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
130 }
131
132 /// Feature: AllToAllvCalcParam
133 /// Description: on 4p, send to rank1,2,3, and recv nothing
134 /// Expectation: send count 0 1 1 1
135 /// send offset 0 0 128 256
136 /// recv count 0 0 0 0
137 /// recv offset 0 0 0 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_4p_only_send)138 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send) {
139 auto graph = std::make_shared<FuncGraph>();
140 ASSERT_TRUE(graph != nullptr);
141 uint32_t rank_size = 4;
142 std::vector<int64_t> send_rank_ids = {1, 2, 3};
143 std::vector<int64_t> recv_rank_ids = {};
144 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
145 recv_rank_ids);
146 ASSERT_TRUE(alltoall != nullptr);
147 ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
148 AllToAllvCalcParam calc(alltoall, rank_size);
149 ASSERT_NO_THROW(calc.CalcOpParam());
150 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
151 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 256}));
152 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
153 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
154 }
155
156 /// Feature: AllToAllvCalcParam
157 /// Description: on 4p, send to rank1,3, and recv nothing
158 /// Expectation: send count 0 1 0 1
159 /// send offset 0 0 128 128
160 /// recv count 0 0 0 0
161 /// recv offset 0 0 0 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_4p_only_send_2)162 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_2) {
163 auto graph = std::make_shared<FuncGraph>();
164 ASSERT_TRUE(graph != nullptr);
165 uint32_t rank_size = 4;
166 std::vector<int64_t> send_rank_ids = {1, 3};
167 std::vector<int64_t> recv_rank_ids = {};
168 auto alltoall =
169 CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}}, {type_, type_}), send_rank_ids, recv_rank_ids);
170 ASSERT_TRUE(alltoall != nullptr);
171 ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
172 AllToAllvCalcParam calc(alltoall, rank_size);
173 ASSERT_NO_THROW(calc.CalcOpParam());
174 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 0, 1}));
175 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 128}));
176 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
177 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
178 }
179
180 /// Feature: AllToAllvCalcParam
181 /// Description: on 2p, send to rank1, and recv from rank1
182 /// Expectation: send count 0 1
183 /// send offset 0 0
184 /// recv count 0 1
185 /// recv offset 0 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_exchange)186 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_exchange) {
187 auto graph = std::make_shared<FuncGraph>();
188 ASSERT_TRUE(graph != nullptr);
189 uint32_t rank_size = 2;
190 std::vector<int64_t> send_rank_ids = {1};
191 std::vector<int64_t> recv_rank_ids = {1};
192 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
193 ASSERT_TRUE(alltoall != nullptr);
194 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
195 AllToAllvCalcParam calc(alltoall, rank_size);
196 ASSERT_NO_THROW(calc.CalcOpParam());
197 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
198 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
199 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 1}));
200 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
201 }
202
203 /// Feature: AllToAllvCalcParam
204 /// Description: on 2p, send to rank0, and recv from rank0
205 /// Expectation: send count 1 0
206 /// send offset 0 128
207 /// recv count 1 0
208 /// recv offset 0 128
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_send_to_self)209 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_send_to_self) {
210 auto graph = std::make_shared<FuncGraph>();
211 ASSERT_TRUE(graph != nullptr);
212 uint32_t rank_size = 2;
213 std::vector<int64_t> send_rank_ids = {0};
214 std::vector<int64_t> recv_rank_ids = {0};
215 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
216 ASSERT_TRUE(alltoall != nullptr);
217 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
218 AllToAllvCalcParam calc(alltoall, rank_size);
219 ASSERT_NO_THROW(calc.CalcOpParam());
220 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 0}));
221 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128}));
222 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 0}));
223 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
224 }
225
226 /// Feature: AllToAllvCalcParam
227 /// Description: on 4p, send to rank0123, and recv from rank0123
228 /// Expectation: send count 1 1 1 1
229 /// send offset 0 128 256 384
230 /// recv count 1 1 1 1
231 /// recv offset 0 128 256 384
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_4p_all_to_all)232 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_to_all) {
233 auto graph = std::make_shared<FuncGraph>();
234 ASSERT_TRUE(graph != nullptr);
235 uint32_t rank_size = 4;
236 std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
237 std::vector<int64_t> recv_rank_ids = {0, 1, 2, 3};
238 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
239 send_rank_ids, recv_rank_ids);
240 ASSERT_TRUE(alltoall != nullptr);
241 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
242 AllToAllvCalcParam calc(alltoall, rank_size);
243 ASSERT_NO_THROW(calc.CalcOpParam());
244 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
245 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
246 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
247 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128, 256, 384}));
248 }
249
250 /// Feature: AllToAllvCalcParam
251 /// Description: on 4p, send to rank0123, and recv from rank0123, but recv order is wrong
252 /// Expectation: send count 1 1 1 1
253 /// send offset 0 128 256 384
254 /// recv count 1 1 1 1
255 /// recv offset 256 128 384 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_4p_all_in_all_in_wrong_order)256 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_in_all_in_wrong_order) {
257 auto graph = std::make_shared<FuncGraph>();
258 ASSERT_TRUE(graph != nullptr);
259 uint32_t rank_size = 4;
260 std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
261 std::vector<int64_t> recv_rank_ids = {3, 1, 0, 2};
262 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
263 send_rank_ids, recv_rank_ids);
264 ASSERT_TRUE(alltoall != nullptr);
265 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
266 AllToAllvCalcParam calc(alltoall, rank_size);
267 ASSERT_NO_THROW(calc.CalcOpParam());
268 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
269 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
270 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
271 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({256, 128, 384, 0}));
272 }
273
274 /// Feature: AllToAllvCalcParam
275 /// Description: on 4p, send to rank123, and recv from nothing, but send order is wrong
276 /// Expectation: send count 0 1 1 1
277 /// send offset 0 128 256 0
278 /// recv count 0 0 0 0
279 /// recv offset 0 0 0 0
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_4p_only_send_in_wrong_order)280 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_in_wrong_order) {
281 auto graph = std::make_shared<FuncGraph>();
282 ASSERT_TRUE(graph != nullptr);
283 uint32_t rank_size = 4;
284 std::vector<int64_t> send_rank_ids = {3, 1, 2};
285 std::vector<int64_t> recv_rank_ids = {};
286 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
287 recv_rank_ids);
288 ASSERT_TRUE(alltoall != nullptr);
289 ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
290 AllToAllvCalcParam calc(alltoall, rank_size);
291 ASSERT_NO_THROW(calc.CalcOpParam());
292 EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
293 EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 0}));
294 EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
295 EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
296 }
297
298 /// Feature: AllToAllvCalcParam
299 /// Description: on 2p, rank id over valid range
300 /// Expectation: throw exception
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_invalid_rank_id)301 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id) {
302 auto graph = std::make_shared<FuncGraph>();
303 ASSERT_TRUE(graph != nullptr);
304 uint32_t rank_size = 2;
305 std::vector<int64_t> send_rank_ids = {};
306 std::vector<int64_t> recv_rank_ids = {0, 2};
307 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
308 ASSERT_TRUE(alltoall != nullptr);
309 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
310 AllToAllvCalcParam calc(alltoall, rank_size);
311 ASSERT_ANY_THROW(calc.CalcOpParam());
312 }
313
314 /// Feature: AllToAllvCalcParam
315 /// Description: on 2p, has 2 outputs but only 1 recv_rank_ids is set
316 /// Expectation: throw exception
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_invalid_rank_id_2)317 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id_2) {
318 auto graph = std::make_shared<FuncGraph>();
319 ASSERT_TRUE(graph != nullptr);
320 uint32_t rank_size = 2;
321 std::vector<int64_t> send_rank_ids = {};
322 std::vector<int64_t> recv_rank_ids = {0};
323 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
324 ASSERT_TRUE(alltoall != nullptr);
325 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
326 AllToAllvCalcParam calc(alltoall, rank_size);
327 ASSERT_ANY_THROW(calc.CalcOpParam());
328 }
329
330 /// Feature: AllToAllvCalcParam
331 /// Description: on 2p, rank id over valid range
332 /// Expectation: throw exception
TEST_F(TestHcclAdapter,test_all_to_all_v_calc_param_2p_wrong_order_and_invalid_rank_id)333 TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_wrong_order_and_invalid_rank_id) {
334 auto graph = std::make_shared<FuncGraph>();
335 ASSERT_TRUE(graph != nullptr);
336 uint32_t rank_size = 2;
337 std::vector<int64_t> send_rank_ids = {};
338 std::vector<int64_t> recv_rank_ids = {2, 0};
339 auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
340 ASSERT_TRUE(alltoall != nullptr);
341 ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
342 AllToAllvCalcParam calc(alltoall, rank_size);
343 ASSERT_ANY_THROW(calc.CalcOpParam());
344 }
345 } // namespace mindspore::hccl
346