• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/common.h"
#include "gtest/gtest.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
28 
29 using namespace mindspore::dataset;
30 using namespace mindspore::dataset::gnn;
31 
32 #define print_int_vec(_i, _str)                                           \
33   do {                                                                    \
34     std::stringstream ss;                                                 \
35     std::copy(_i.begin(), _i.end(), std::ostream_iterator<int>(ss, " ")); \
36     MS_LOG(INFO) << _str << " " << ss.str();                              \
37   } while (false)
38 
39 class MindDataTestGNNGraph : public UT::Common {
40  protected:
41   MindDataTestGNNGraph() = default;
42 
43   using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
44   using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;
ParsingNeighbors(const std::shared_ptr<Tensor> & neighbors,NodeNeighborsMap & node_neighbors)45   void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
46     auto shape_vec = neighbors->shape().AsVector();
47     uint32_t num_members = 1;
48     for (size_t i = 1; i < shape_vec.size(); ++i) {
49       num_members *= shape_vec[i];
50     }
51     uint32_t index = 0;
52     NodeIdType src_node = 0;
53     for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
54          ++node_itr, ++index) {
55       if (index % num_members == 0) {
56         src_node = *node_itr;
57         continue;
58       }
59       auto src_node_itr = node_neighbors.find(src_node);
60       if (src_node_itr == node_neighbors.end()) {
61         node_neighbors[src_node] = {{*node_itr, 1}};
62       } else {
63         auto nei_itr = src_node_itr->second.find(*node_itr);
64         if (nei_itr == src_node_itr->second.end()) {
65           src_node_itr->second[*node_itr] = 1;
66         } else {
67           src_node_itr->second[*node_itr] += 1;
68         }
69       }
70     }
71   }
72 
CheckNeighborsRatio(const NumNeighborsMap & number_neighbors,const std::vector<WeightType> & weights,float deviation_ratio=0.2)73   void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
74                            float deviation_ratio = 0.2) {
75     EXPECT_EQ(number_neighbors.size(), weights.size());
76     int index = 0;
77     uint32_t pre_num = 0;
78     WeightType pre_weight = 1;
79     for (auto neighbor : number_neighbors) {
80       if (pre_num != 0) {
81         float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
82         float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
83         float target_upper = target_ratio * (1 + deviation_ratio);
84         float target_lower = target_ratio * (1 - deviation_ratio);
85         MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
86                      << " target_upper:" << std::to_string(target_upper)
87                      << " target_lower:" << std::to_string(target_lower);
88         EXPECT_LE(current_ratio, target_upper);
89         EXPECT_GE(current_ratio, target_lower);
90       }
91       pre_num = neighbor.second;
92       pre_weight = weights[index];
93       ++index;
94     }
95   }
96 };
97 
// Resolves (src, dst) node-id pairs to edge ids and checks the ids returned for the test graph.
TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  // Everything below needs a loaded graph; stop instead of cascading into follow-up failures.
  ASSERT_TRUE(s.IsOk());

  std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208},
                                                                 {110, 201}, {204, 105}, {208, 108}};
  std::shared_ptr<Tensor> edges;
  s = graph.GetEdgesFromNodes(src_dst_list, &edges);
  // `edges` is dereferenced below, so a failed call must end the test here.
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(edges->ToString(), "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]");
}
112 
// Fetches all neighbors (normal output format) of at most the first 10 nodes of the first node
// type, and checks the node features returned for every node of that type.
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Query with at most the first 10 nodes.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
    if (node_list.size() >= 10) {
      break;
    }
  }
  std::shared_ptr<Tensor> neighbors;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors->shape().ToString(), "<10,6>");

  // Features are requested for every node in `nodes`, not only the ones in `node_list`.
  TensorRow features;
  s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
  ASSERT_TRUE(s.IsOk());
  // features[0..2] are dereferenced below.
  ASSERT_EQ(features.size(), 4u);
  EXPECT_EQ(features[0]->shape().ToString(), "<10,5>");
  EXPECT_EQ(features[0]->ToString(),
            "Tensor (shape: <10,5>, Type: int32)\n"
            "[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1,"
            "0,0],[0,1,0,1,0]]");
  EXPECT_EQ(features[1]->shape().ToString(), "<10>");
  EXPECT_EQ(features[1]->ToString(),
            "Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]");
  EXPECT_EQ(features[2]->shape().ToString(), "<10>");
  EXPECT_EQ(features[2]->ToString(), "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
}
153 
// Fetches all neighbors of at most the first 10 nodes of the first node type in COO and CSR
// output formats and checks the exact tensors produced for the test graph.
TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Query with at most the first 10 nodes.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
    if (node_list.size() >= 10) {
      break;
    }
  }
  // Check COO format
  std::shared_ptr<Tensor> neighbors_coo;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors_coo->shape().ToString(), "<20,2>");
  EXPECT_EQ(neighbors_coo->ToString(),
            "Tensor (shape: <20,2>, Type: int32)\n"
            "[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208],"
            "[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]");
  // Check CSR format
  std::shared_ptr<Tensor> neighbors_csr;
  s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr);
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neighbors_csr->shape().ToString(), "<30>");
  EXPECT_EQ(
    neighbors_csr->ToString(),
    "Tensor (shape: <30>, Type: int32)\n"
    "[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]");
}
194 
TEST_F(MindDataTestGNNGraph,TestGetSampledNeighbors)195 TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
196   std::string path = "data/mindrecord/testGraphData/testdata";
197   GraphDataImpl graph(path, 1);
198   Status s = graph.Init();
199   EXPECT_TRUE(s.IsOk());
200 
201   MetaInfo meta_info;
202   s = graph.GetMetaInfo(&meta_info);
203   EXPECT_TRUE(s.IsOk());
204   EXPECT_TRUE(meta_info.node_type.size() == 2);
205 
206   std::shared_ptr<Tensor> edges;
207   s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
208   EXPECT_TRUE(s.IsOk());
209   std::vector<EdgeIdType> edge_list;
210   edge_list.resize(edges->Size());
211   std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
212                  [](const EdgeIdType edge) { return edge; });
213 
214   TensorRow edge_features;
215   s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features);
216   EXPECT_TRUE(s.IsOk());
217   EXPECT_TRUE(edge_features[0]->ToString() ==
218               "Tensor (shape: <40>, Type: int32)\n"
219               "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]");
220   EXPECT_TRUE(edge_features[1]->ToString() ==
221               "Tensor (shape: <40>, Type: float32)\n"
222               "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2."
223               "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]");
224 
225   std::shared_ptr<Tensor> nodes;
226   s = graph.GetNodesFromEdges(edge_list, &nodes);
227   EXPECT_TRUE(s.IsOk());
228   std::unordered_set<NodeIdType> node_set;
229   std::vector<NodeIdType> node_list;
230   int index = 0;
231   for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
232     index++;
233     if (index % 2 == 0) {
234       continue;
235     }
236     node_set.emplace(*itr);
237     if (node_set.size() >= 5) {
238       break;
239     }
240   }
241   node_list.resize(node_set.size());
242   std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
243 
244   std::shared_ptr<Tensor> neighbors;
245   {
246     MS_LOG(INFO) << "Test random sampling.";
247     NodeNeighborsMap number_neighbors;
248     int count = 0;
249     while (count < 1000) {
250       neighbors.reset();
251       s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
252       EXPECT_TRUE(s.IsOk());
253       EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
254       ParsingNeighbors(neighbors, number_neighbors);
255       ++count;
256     }
257     CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
258   }
259 
260   {
261     MS_LOG(INFO) << "Test edge weight sampling.";
262     NodeNeighborsMap number_neighbors;
263     int count = 0;
264     while (count < 1000) {
265       neighbors.reset();
266       s =
267         graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
268       EXPECT_TRUE(s.IsOk());
269       EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
270       ParsingNeighbors(neighbors, number_neighbors);
271       ++count;
272     }
273     CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
274   }
275 
276   neighbors.reset();
277   s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
278                                 SamplingStrategy::kRandom, &neighbors);
279   EXPECT_TRUE(s.IsOk());
280   EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
281 
282   neighbors.reset();
283   s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
284                                 {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
285                                 SamplingStrategy::kRandom, &neighbors);
286   EXPECT_TRUE(s.IsOk());
287   EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
288 
289   neighbors.reset();
290   s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
291   EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
292 
293   neighbors.reset();
294   s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
295   EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
296 
297   neighbors.reset();
298   s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
299                                 SamplingStrategy::kRandom, &neighbors);
300   EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
301 
302   neighbors.reset();
303   s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
304   EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
305 
306   neighbors.reset();
307   s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
308                                 SamplingStrategy::kRandom, &neighbors);
309   EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
310               std::string::npos);
311 
312   neighbors.reset();
313   s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
314   EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
315 }
316 
// Exercises GetNegSampledNeighbors. Checks the output shape for a valid request, then the error
// messages for an empty node list, an invalid node id, too many samples, and an invalid
// neighbor type.
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
  std::string path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] and node_type[1] are indexed below.
  ASSERT_EQ(meta_info.node_type.size(), 2u);

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  // Query with at most the first 10 nodes.
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
    if (node_list.size() >= 10) {
      break;
    }
  }
  std::shared_ptr<Tensor> neg_neighbors;
  s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
  // `neg_neighbors` is dereferenced below, so a failed call must end the test here.
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(neg_neighbors->shape().ToString(), "<10,4>");

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
  EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors);
  EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors);
  EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);

  neg_neighbors.reset();
  s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
  EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
}
359 
// Runs a random walk from every node of the first node type in the "sns" graph, following a
// meta path of 59 steps of node type 1. The expected path width (60) is one more than the
// meta-path length.
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
  std::string path = "data/mindrecord/testGraphData/sns";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] is indexed below.
  ASSERT_FALSE(meta_info.node_type.empty());

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
  }

  print_int_vec(node_list, "node list ");
  std::vector<NodeType> meta_path(59, 1);
  std::shared_ptr<Tensor> walk_path;
  // 2.0 / 0.5 are non-default walk parameters (presumably node2vec-style p / q — confirm against
  // GraphDataImpl::RandomWalk); the meaning of the -1 argument is not visible here (TODO confirm).
  s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
  // `walk_path` is dereferenced below, so a failed call must end the test here.
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(walk_path->shape().ToString(), "<33,60>");
}
385 
// Same walk as TestRandomWalk, but with both walk parameters set to 1.0 (the values the test
// name suggests are the defaults). Starts from every node of the first node type in the "sns"
// graph and follows a meta path of 59 steps of node type 1. The expected path width (60) is one
// more than the meta-path length.
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
  std::string path = "data/mindrecord/testGraphData/sns";
  GraphDataImpl graph(path, 1);
  Status s = graph.Init();
  ASSERT_TRUE(s.IsOk());

  MetaInfo meta_info;
  s = graph.GetMetaInfo(&meta_info);
  ASSERT_TRUE(s.IsOk());
  // node_type[0] is indexed below.
  ASSERT_FALSE(meta_info.node_type.empty());

  std::shared_ptr<Tensor> nodes;
  s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  ASSERT_TRUE(s.IsOk());
  std::vector<NodeIdType> node_list;
  for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
    node_list.push_back(*itr);
  }

  print_int_vec(node_list, "node list ");
  std::vector<NodeType> meta_path(59, 1);
  std::shared_ptr<Tensor> walk_path;
  s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path);
  // `walk_path` is dereferenced below, so a failed call must end the test here.
  ASSERT_TRUE(s.IsOk());
  EXPECT_EQ(walk_path->shape().ToString(), "<33,60>");
}
411