1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <algorithm> 17 #include <string> 18 #include <map> 19 #include <memory> 20 #include <unordered_set> 21 22 #include "common/common.h" 23 #include "gtest/gtest.h" 24 #include "minddata/dataset/util/status.h" 25 #include "minddata/dataset/engine/gnn/node.h" 26 #include "minddata/dataset/engine/gnn/graph_data_impl.h" 27 #include "minddata/dataset/engine/gnn/graph_loader.h" 28 29 using namespace mindspore::dataset; 30 using namespace mindspore::dataset::gnn; 31 32 #define print_int_vec(_i, _str) \ 33 do { \ 34 std::stringstream ss; \ 35 std::copy(_i.begin(), _i.end(), std::ostream_iterator<int>(ss, " ")); \ 36 MS_LOG(INFO) << _str << " " << ss.str(); \ 37 } while (false) 38 39 class MindDataTestGNNGraph : public UT::Common { 40 protected: 41 MindDataTestGNNGraph() = default; 42 43 using NumNeighborsMap = std::map<NodeIdType, uint32_t>; 44 using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>; 45 void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) { 46 auto shape_vec = neighbors->shape().AsVector(); 47 uint32_t num_members = 1; 48 for (size_t i = 1; i < shape_vec.size(); ++i) { 49 num_members *= shape_vec[i]; 50 } 51 uint32_t index = 0; 52 NodeIdType src_node = 0; 53 for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>(); 54 ++node_itr, ++index) { 55 if (index % num_members == 0) { 56 src_node = *node_itr; 57 continue; 58 } 59 auto src_node_itr = node_neighbors.find(src_node); 60 if (src_node_itr == node_neighbors.end()) { 61 node_neighbors[src_node] = {{*node_itr, 1}}; 62 } else { 63 auto nei_itr = src_node_itr->second.find(*node_itr); 64 if (nei_itr == src_node_itr->second.end()) { 65 src_node_itr->second[*node_itr] = 1; 66 } else { 67 src_node_itr->second[*node_itr] += 1; 68 } 69 } 70 } 71 } 72 73 void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights, 74 float deviation_ratio = 0.2) { 75 EXPECT_EQ(number_neighbors.size(), weights.size()); 76 int index = 0; 77 uint32_t pre_num = 0; 78 WeightType pre_weight = 1; 79 for (auto neighbor : number_neighbors) { 80 if (pre_num != 0) { 81 float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]); 82 float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second); 83 float target_upper = target_ratio * (1 + deviation_ratio); 84 float target_lower = target_ratio * (1 - deviation_ratio); 85 MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio) 86 << " target_upper:" << std::to_string(target_upper) 87 << " target_lower:" << std::to_string(target_lower); 88 EXPECT_LE(current_ratio, target_upper); 89 EXPECT_GE(current_ratio, target_lower); 90 } 91 pre_num = neighbor.second; 92 pre_weight = weights[index]; 93 ++index; 94 } 95 } 96 }; 97 98 TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) { 99 std::string path = "data/mindrecord/testGraphData/testdata"; 100 GraphDataImpl graph(path, 1); 101 Status s = graph.Init(); 102 EXPECT_TRUE(s.IsOk()); 103 104 std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208}, 105 {110, 201}, {204, 105}, {208, 108}}; 106 std::shared_ptr<Tensor> edges; 107 s = graph.GetEdgesFromNodes(src_dst_list, &edges); 108 109 EXPECT_TRUE(s.IsOk()); 110 EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]"); 111 } 112 113 TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { 114 std::string path = "data/mindrecord/testGraphData/testdata"; 115 GraphDataImpl graph(path, 1); 116 Status s = graph.Init(); 117 EXPECT_TRUE(s.IsOk()); 118 119 MetaInfo meta_info; 120 s = graph.GetMetaInfo(&meta_info); 121 EXPECT_TRUE(s.IsOk()); 122 EXPECT_TRUE(meta_info.node_type.size() == 2); 123 124 std::shared_ptr<Tensor> nodes; 125 s = graph.GetAllNodes(meta_info.node_type[0], &nodes); 126 EXPECT_TRUE(s.IsOk()); 127 std::vector<NodeIdType> node_list; 128 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 129 node_list.push_back(*itr); 130 if (node_list.size() >= 10) { 131 break; 132 } 133 } 134 std::shared_ptr<Tensor> neighbors; 135 s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors); 136 EXPECT_TRUE(s.IsOk()); 137 EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); 138 TensorRow features; 139 s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features); 140 EXPECT_TRUE(s.IsOk()); 141 EXPECT_TRUE(features.size() == 4); 142 EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); 143 EXPECT_TRUE(features[0]->ToString() == 144 "Tensor (shape: <10,5>, Type: int32)\n" 145 "[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1," 146 "0,0],[0,1,0,1,0]]"); 147 EXPECT_TRUE(features[1]->shape().ToString() == "<10>"); 148 EXPECT_TRUE(features[1]->ToString() == 149 "Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]"); 150 EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); 151 EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); 152 } 153 154 TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) { 155 std::string path = "data/mindrecord/testGraphData/testdata"; 156 GraphDataImpl graph(path, 1); 157 Status s = graph.Init(); 158 EXPECT_TRUE(s.IsOk()); 159 160 MetaInfo meta_info; 161 s = graph.GetMetaInfo(&meta_info); 162 EXPECT_TRUE(s.IsOk()); 163 EXPECT_TRUE(meta_info.node_type.size() == 2); 164 165 std::shared_ptr<Tensor> nodes; 166 s = graph.GetAllNodes(meta_info.node_type[0], &nodes); 167 EXPECT_TRUE(s.IsOk()); 168 std::vector<NodeIdType> node_list; 169 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 170 node_list.push_back(*itr); 171 if (node_list.size() >= 10) { 172 break; 173 } 174 } 175 // Check COO format 176 std::shared_ptr<Tensor> neighbors_coo; 177 s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo); 178 EXPECT_TRUE(s.IsOk()); 179 EXPECT_TRUE(neighbors_coo->shape().ToString() == "<20,2>"); 180 EXPECT_TRUE(neighbors_coo->ToString() == 181 "Tensor (shape: <20,2>, Type: int32)\n" 182 "[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208]," 183 "[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]"); 184 // Check CSR format 185 std::shared_ptr<Tensor> neighbors_csr; 186 s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr); 187 EXPECT_TRUE(s.IsOk()); 188 EXPECT_TRUE(neighbors_csr->shape().ToString() == "<30>"); 189 EXPECT_TRUE( 190 neighbors_csr->ToString() == 191 "Tensor (shape: <30>, Type: int32)\n" 192 "[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]"); 193 } 194 195 TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { 196 std::string path = "data/mindrecord/testGraphData/testdata"; 197 GraphDataImpl graph(path, 1); 198 Status s = graph.Init(); 199 EXPECT_TRUE(s.IsOk()); 200 201 MetaInfo meta_info; 202 s = graph.GetMetaInfo(&meta_info); 203 EXPECT_TRUE(s.IsOk()); 204 EXPECT_TRUE(meta_info.node_type.size() == 2); 205 206 std::shared_ptr<Tensor> edges; 207 s = graph.GetAllEdges(meta_info.edge_type[0], &edges); 208 EXPECT_TRUE(s.IsOk()); 209 std::vector<EdgeIdType> edge_list; 210 edge_list.resize(edges->Size()); 211 std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(), 212 [](const EdgeIdType edge) { return edge; }); 213 214 TensorRow edge_features; 215 s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features); 216 EXPECT_TRUE(s.IsOk()); 217 EXPECT_TRUE(edge_features[0]->ToString() == 218 "Tensor (shape: <40>, Type: int32)\n" 219 "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]"); 220 EXPECT_TRUE(edge_features[1]->ToString() == 221 "Tensor (shape: <40>, Type: float32)\n" 222 "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2." 223 "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]"); 224 225 std::shared_ptr<Tensor> nodes; 226 s = graph.GetNodesFromEdges(edge_list, &nodes); 227 EXPECT_TRUE(s.IsOk()); 228 std::unordered_set<NodeIdType> node_set; 229 std::vector<NodeIdType> node_list; 230 int index = 0; 231 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 232 index++; 233 if (index % 2 == 0) { 234 continue; 235 } 236 node_set.emplace(*itr); 237 if (node_set.size() >= 5) { 238 break; 239 } 240 } 241 node_list.resize(node_set.size()); 242 std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); 243 244 std::shared_ptr<Tensor> neighbors; 245 { 246 MS_LOG(INFO) << "Test random sampling."; 247 NodeNeighborsMap number_neighbors; 248 int count = 0; 249 while (count < 1000) { 250 neighbors.reset(); 251 s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); 252 EXPECT_TRUE(s.IsOk()); 253 EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); 254 ParsingNeighbors(neighbors, number_neighbors); 255 ++count; 256 } 257 CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1}); 258 } 259 260 { 261 MS_LOG(INFO) << "Test edge weight sampling."; 262 NodeNeighborsMap number_neighbors; 263 int count = 0; 264 while (count < 1000) { 265 neighbors.reset(); 266 s = 267 graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors); 268 EXPECT_TRUE(s.IsOk()); 269 EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); 270 ParsingNeighbors(neighbors, number_neighbors); 271 ++count; 272 } 273 CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8}); 274 } 275 276 neighbors.reset(); 277 s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, 278 SamplingStrategy::kRandom, &neighbors); 279 EXPECT_TRUE(s.IsOk()); 280 EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); 281 282 neighbors.reset(); 283 s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, 284 {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, 285 SamplingStrategy::kRandom, &neighbors); 286 EXPECT_TRUE(s.IsOk()); 287 EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); 288 289 neighbors.reset(); 290 s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); 291 EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); 292 293 neighbors.reset(); 294 s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); 295 EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); 296 297 neighbors.reset(); 298 s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, 299 SamplingStrategy::kRandom, &neighbors); 300 EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); 301 302 neighbors.reset(); 303 s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors); 304 EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); 305 306 neighbors.reset(); 307 s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, 308 SamplingStrategy::kRandom, &neighbors); 309 EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != 310 std::string::npos); 311 312 neighbors.reset(); 313 s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); 314 EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); 315 } 316 317 TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { 318 std::string path = "data/mindrecord/testGraphData/testdata"; 319 GraphDataImpl graph(path, 1); 320 Status s = graph.Init(); 321 EXPECT_TRUE(s.IsOk()); 322 323 MetaInfo meta_info; 324 s = graph.GetMetaInfo(&meta_info); 325 EXPECT_TRUE(s.IsOk()); 326 EXPECT_TRUE(meta_info.node_type.size() == 2); 327 328 std::shared_ptr<Tensor> nodes; 329 s = graph.GetAllNodes(meta_info.node_type[0], &nodes); 330 EXPECT_TRUE(s.IsOk()); 331 std::vector<NodeIdType> node_list; 332 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 333 node_list.push_back(*itr); 334 if (node_list.size() >= 10) { 335 break; 336 } 337 } 338 std::shared_ptr<Tensor> neg_neighbors; 339 s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors); 340 EXPECT_TRUE(s.IsOk()); 341 EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>"); 342 343 neg_neighbors.reset(); 344 s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); 345 EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); 346 347 neg_neighbors.reset(); 348 s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors); 349 EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); 350 351 neg_neighbors.reset(); 352 s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors); 353 EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); 354 355 neg_neighbors.reset(); 356 s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); 357 EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); 358 } 359 360 TEST_F(MindDataTestGNNGraph, TestRandomWalk) { 361 std::string path = "data/mindrecord/testGraphData/sns"; 362 GraphDataImpl graph(path, 1); 363 Status s = graph.Init(); 364 EXPECT_TRUE(s.IsOk()); 365 366 MetaInfo meta_info; 367 s = graph.GetMetaInfo(&meta_info); 368 EXPECT_TRUE(s.IsOk()); 369 370 std::shared_ptr<Tensor> nodes; 371 s = graph.GetAllNodes(meta_info.node_type[0], &nodes); 372 EXPECT_TRUE(s.IsOk()); 373 std::vector<NodeIdType> node_list; 374 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 375 node_list.push_back(*itr); 376 } 377 378 print_int_vec(node_list, "node list "); 379 std::vector<NodeType> meta_path(59, 1); 380 std::shared_ptr<Tensor> walk_path; 381 s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); 382 EXPECT_TRUE(s.IsOk()); 383 EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); 384 } 385 386 TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { 387 std::string path = "data/mindrecord/testGraphData/sns"; 388 GraphDataImpl graph(path, 1); 389 Status s = graph.Init(); 390 EXPECT_TRUE(s.IsOk()); 391 392 MetaInfo meta_info; 393 s = graph.GetMetaInfo(&meta_info); 394 EXPECT_TRUE(s.IsOk()); 395 396 std::shared_ptr<Tensor> nodes; 397 s = graph.GetAllNodes(meta_info.node_type[0], &nodes); 398 EXPECT_TRUE(s.IsOk()); 399 std::vector<NodeIdType> node_list; 400 for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { 401 node_list.push_back(*itr); 402 } 403 404 print_int_vec(node_list, "node list "); 405 std::vector<NodeType> meta_path(59, 1); 406 std::shared_ptr<Tensor> walk_path; 407 s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path); 408 EXPECT_TRUE(s.IsOk()); 409 EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); 410 } 411