1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
21
22 #include "common/common.h"
23 #include "gtest/gtest.h"
24 #include "minddata/dataset/util/status.h"
25 #include "minddata/dataset/engine/gnn/node.h"
26 #include "minddata/dataset/engine/gnn/graph_data_impl.h"
27 #include "minddata/dataset/engine/gnn/graph_loader.h"
28
29 using namespace mindspore::dataset;
30 using namespace mindspore::dataset::gnn;
31
// Logs the elements of a container (converted to int) on one INFO line, after a label.
//   _i   : any expression whose result exposes begin()/end() over int-convertible values.
//   _str : label streamed before the values.
// Arguments are parenthesized so expression arguments (e.g. `cond ? a : b`) expand correctly.
// The local stream has a macro-specific name so it cannot shadow a caller variable passed in.
#define print_int_vec(_i, _str)                                                               \
  do {                                                                                        \
    std::stringstream print_int_vec_ss_;                                                      \
    std::copy((_i).begin(), (_i).end(), std::ostream_iterator<int>(print_int_vec_ss_, " ")); \
    MS_LOG(INFO) << (_str) << " " << print_int_vec_ss_.str();                                 \
  } while (false)
38
39 class MindDataTestGNNGraph : public UT::Common {
40 protected:
41 MindDataTestGNNGraph() = default;
42
43 using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
44 using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;
ParsingNeighbors(const std::shared_ptr<Tensor> & neighbors,NodeNeighborsMap & node_neighbors)45 void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
46 auto shape_vec = neighbors->shape().AsVector();
47 uint32_t num_members = 1;
48 for (size_t i = 1; i < shape_vec.size(); ++i) {
49 num_members *= shape_vec[i];
50 }
51 uint32_t index = 0;
52 NodeIdType src_node = 0;
53 for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
54 ++node_itr, ++index) {
55 if (index % num_members == 0) {
56 src_node = *node_itr;
57 continue;
58 }
59 auto src_node_itr = node_neighbors.find(src_node);
60 if (src_node_itr == node_neighbors.end()) {
61 node_neighbors[src_node] = {{*node_itr, 1}};
62 } else {
63 auto nei_itr = src_node_itr->second.find(*node_itr);
64 if (nei_itr == src_node_itr->second.end()) {
65 src_node_itr->second[*node_itr] = 1;
66 } else {
67 src_node_itr->second[*node_itr] += 1;
68 }
69 }
70 }
71 }
72
CheckNeighborsRatio(const NumNeighborsMap & number_neighbors,const std::vector<WeightType> & weights,float deviation_ratio=0.2)73 void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
74 float deviation_ratio = 0.2) {
75 EXPECT_EQ(number_neighbors.size(), weights.size());
76 int index = 0;
77 uint32_t pre_num = 0;
78 WeightType pre_weight = 1;
79 for (auto neighbor : number_neighbors) {
80 if (pre_num != 0) {
81 float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
82 float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
83 float target_upper = target_ratio * (1 + deviation_ratio);
84 float target_lower = target_ratio * (1 - deviation_ratio);
85 MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
86 << " target_upper:" << std::to_string(target_upper)
87 << " target_lower:" << std::to_string(target_lower);
88 EXPECT_LE(current_ratio, target_upper);
89 EXPECT_GE(current_ratio, target_lower);
90 }
91 pre_num = neighbor.second;
92 pre_weight = weights[index];
93 ++index;
94 }
95 }
96 };
97
// Looks up edges for a list of (src, dst) node pairs and checks the returned edge tensor.
TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  // Every later step needs an initialized graph; stop the test here on failure.
  ASSERT_TRUE(s.IsOk());

  std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208},
                                                                 {110, 201}, {204, 105}, {208, 108}};
  std::shared_ptr<Tensor> edges;
  s = graph.GetEdgesFromNodes(src_dst_list, &edges);
  // `edges` starts out null and is dereferenced below, so a failed call must end the test
  // instead of crashing it.
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(edges->ToString(), "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]");
}
112
// Fetches all neighbors (normal output format) of the first 10 nodes of the first node type,
// then checks the node features of that node type.
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Keep at most the first 10 node ids.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>() && node_list.size() < 10; ++itr) {
    node_list.push_back(*itr);
  }

  std::shared_ptr<Tensor> neighbors;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors->shape().ToString(), "<10,6>");

  TensorRow features;
  s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
  ASSERT_TRUE(s.IsOk());
  // features[0..2] are indexed below.
  ASSERT_EQ(features.size(), 4u);
  EXPECT_EQ(features[0]->shape().ToString(), "<10,5>");
  EXPECT_EQ(features[0]->ToString(),
            "Tensor (shape: <10,5>, Type: int32)\n"
            "[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1,"
            "0,0],[0,1,0,1,0]]");
  EXPECT_EQ(features[1]->shape().ToString(), "<10>");
  EXPECT_EQ(features[1]->ToString(),
            "Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]");
  EXPECT_EQ(features[2]->shape().ToString(), "<10>");
  EXPECT_EQ(features[2]->ToString(), "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
}
153
// Fetches all neighbors of the first 10 nodes of the first node type in COO and CSR output
// formats and checks the exact tensors.
TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Keep at most the first 10 node ids.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>() && node_list.size() < 10; ++itr) {
    node_list.push_back(*itr);
  }

  // COO: one [src, neighbor] row per edge.
  std::shared_ptr<Tensor> neighbors_coo;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors_coo->shape().ToString(), "<20,2>");
  EXPECT_EQ(neighbors_coo->ToString(),
            "Tensor (shape: <20,2>, Type: int32)\n"
            "[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208],"
            "[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]");

  // CSR: per the expected data, 10 per-node offsets followed by the 20 concatenated neighbors.
  std::shared_ptr<Tensor> neighbors_csr;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors_csr->shape().ToString(), "<30>");
  EXPECT_EQ(
    neighbors_csr->ToString(),
    "Tensor (shape: <30>, Type: int32)\n"
    "[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]");
}
194
// Exercises GetSampledNeighbors: sampled-neighbor frequencies under random and edge-weight
// strategies, multi-hop output shapes, and the error messages reported for invalid arguments.
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0], node_type[1] and edge_type[0] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);
  ASSERT_FALSE(meta_info.edge_type.empty());

  std::shared_ptr<Tensor> edges;
  s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
  ASSERT_TRUE(s.IsOk());
  std::vector<EdgeIdType> edge_list;
  edge_list.reserve(edges->Size());
  for (auto itr = edges->begin<EdgeIdType>(); itr != edges->end<EdgeIdType>(); ++itr) {
    edge_list.push_back(*itr);
  }

  TensorRow edge_features;
  s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features);
  ASSERT_TRUE(s.IsOk());
  // edge_features[0] and edge_features[1] are indexed below.
  ASSERT_GE(edge_features.size(), 2u);
  EXPECT_EQ(edge_features[0]->ToString(),
            "Tensor (shape: <40>, Type: int32)\n"
            "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]");
  EXPECT_EQ(edge_features[1]->ToString(),
            "Tensor (shape: <40>, Type: float32)\n"
            "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2."
            "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]");

  std::shared_ptr<Tensor> nodes;
  s = graph.GetNodesFromEdges(edge_list, &nodes);
  ASSERT_TRUE(s.IsOk());
  // Collect up to 5 distinct ids taken from the even positions (0, 2, 4, ...) of the output —
  // presumably the source side of each (src, dst) pair; confirm against GetNodesFromEdges.
  std::unordered_set<NodeIdType> node_set;
  size_t position = 0;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr, ++position) {
    if (position % 2 != 0) {
      continue;
    }
    node_set.emplace(*itr);
    if (node_set.size() >= 5) {
      break;
    }
  }
  std::vector<NodeIdType> node_list(node_set.begin(), node_set.end());

  std::shared_ptr<Tensor> neighbors;
  {
    MS_LOG(INFO) << "Test random sampling.";
    NodeNeighborsMap number_neighbors;
    for (int count = 0; count < 1000; ++count) {
      neighbors.reset();
      s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
      ASSERT_TRUE(s.IsOk());
      EXPECT_EQ(neighbors->shape().ToString(), "<5,11>");
      ParsingNeighbors(neighbors, number_neighbors);
    }
    // Uniform sampling: all five neighbors of node 103 should appear about equally often.
    CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
  }

  {
    MS_LOG(INFO) << "Test edge weight sampling.";
    NodeNeighborsMap number_neighbors;
    for (int count = 0; count < 1000; ++count) {
      neighbors.reset();
      s =
        graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
      ASSERT_TRUE(s.IsOk());
      EXPECT_EQ(neighbors->shape().ToString(), "<5,11>");
      ParsingNeighbors(neighbors, number_neighbors);
    }
    // Weighted sampling: frequencies for node 103's neighbors should follow these weights.
    CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
  }

  neighbors.reset();
  s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
                                SamplingStrategy::kRandom, &neighbors);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors->shape().ToString(), "<5,9>");

  neighbors.reset();
  s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
                                {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
                                SamplingStrategy::kRandom, &neighbors);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors->shape().ToString(), "<5,33>");

  // Invalid-argument cases: only the error message is checked.
  neighbors.reset();
  s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("Input node_list is empty."), std::string::npos);

  neighbors.reset();
  s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("Invalid node id"), std::string::npos);

  neighbors.reset();
  s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
                                SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("Wrong samples number"), std::string::npos);

  neighbors.reset();
  s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("Invalid neighbor type"), std::string::npos);

  neighbors.reset();
  s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
                                SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent."),
            std::string::npos);

  neighbors.reset();
  s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  EXPECT_NE(s.ToString().find("Invalid node id:301"), std::string::npos);
}
316
// Checks the output shape of GetNegSampledNeighbors for the first 10 nodes of the first node
// type, and the error messages reported for invalid arguments.
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Keep at most the first 10 node ids.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>() && node_list.size() < 10; ++itr) {
    node_list.push_back(*itr);
  }

  std::shared_ptr<Tensor> neg_neighbors;
  s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neg_neighbors->shape().ToString(), "<10,4>");

  // Invalid-argument cases: only the error message is checked.
  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
  EXPECT_NE(s.ToString().find("Input node_list is empty."), std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors);
  EXPECT_NE(s.ToString().find("Invalid node id"), std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors);
  EXPECT_NE(s.ToString().find("Wrong samples number"), std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
  EXPECT_NE(s.ToString().find("Invalid neighbor type"), std::string::npos);
}
359
// Runs RandomWalk from every node of the first node type in the "sns" graph along a 59-step
// meta path of node type 1, with walk-bias parameters 2.0 and 0.5 (presumably the return and
// in-out parameters; confirm against GraphDataImpl::RandomWalk), and checks the output shape.
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
  std::string path = "data/mindrecord/testGraphData/sns";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] is indexed below.
  ASSERT_FALSE(meta_info.node_type.empty());

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
  }

  print_int_vec(node_list, "node list ");
  std::vector<NodeType> meta_path(59, 1);
  std::shared_ptr<Tensor> walk_path;
  s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
  ASSERT_TRUE(s.IsOk());
  // One row per start node; 60 columns for a 59-step meta path.
  EXPECT_EQ(walk_path->shape().ToString(), "<33,60>");
}
385
// Same walk as TestRandomWalk but with both walk-bias parameters set to 1.0 (the test name
// suggests these are the defaults; confirm against GraphDataImpl::RandomWalk).
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
  std::string path = "data/mindrecord/testGraphData/sns";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] is indexed below.
  ASSERT_FALSE(meta_info.node_type.empty());

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
  }

  print_int_vec(node_list, "node list ");
  std::vector<NodeType> meta_path(59, 1);
  std::shared_ptr<Tensor> walk_path;
  s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path);
  ASSERT_TRUE(s.IsOk());
  // One row per start node; 60 columns for a 59-step meta path.
  EXPECT_EQ(walk_path->shape().ToString(), "<33,60>");
}
411