# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import random
import pytest
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
from mindspore.dataset.engine import OutputFormat

DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"


def test_graphdata_getfullneighbor():
    """
    Fetch every type-2 neighbor of all type-1 nodes and look up features for them.

    Checks the shape of the neighbor matrix and of the first feature returned
    for that matrix.
    """
    logger.info('test get all neighbors.\n')
    graph = ds.GraphData(DATASET_FILE, 2)

    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    full_neighbors = graph.get_all_neighbors(type1_nodes, 2)
    assert full_neighbors.shape == (10, 6)

    # The feature lookup is fed a plain nested list rather than the ndarray.
    neighbor_ids = full_neighbors.tolist()
    features = graph.get_node_feature(neighbor_ids, [2, 3])
    assert features[0].shape == (10, 6)


def test_graphdata_getallneighbors_special_format():
    """
    Fetch all type-2 neighbors of the type-1 nodes in COO and CSR output formats.

    Checks the shape of the COO result and of both arrays returned for CSR.
    """
    logger.info('test get all neighbors with special format.\n')
    graph = ds.GraphData(DATASET_FILE, 2)

    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    # COO: a single two-column array.
    coo_pairs = graph.get_all_neighbors(type1_nodes, 2, OutputFormat.COO)
    assert coo_pairs.shape == (20, 2)

    # CSR: an offset table plus a flat neighbor array.
    csr_offsets, csr_neighbors = graph.get_all_neighbors(type1_nodes, 2, OutputFormat.CSR)
    assert csr_offsets.shape == (10,)
    assert csr_neighbors.shape == (20,)


def test_graphdata_getnodefeature_input_check():
    """
    Test get node feature input check

    Every (node_list, feature_types) pair below must be rejected by
    get_node_feature with a TypeError.
    """
    logger.info('test getnodefeature input check.\n')
    g = ds.GraphData(DATASET_FILE)

    well_formed = [[1, 1], [1, 1]]
    invalid_calls = [
        # node_list mixing scalars and lists
        ([1, [1, 1]], [1]),
        ([[1, 1], 1], [1]),
        # node_list rows of different lengths
        ([[1, 1], [1, 1, 1]], [1]),
        ([[1, 1, 1], [1, 1]], [1]),
        # node_list with an extra level of nesting in one row
        ([[1, 1], [1, [1, 1]]], [1]),
        ([[1, 1], [[1, 1], 1]], [1]),
        # feature_types passed as a bare int (with int, float and ndarray node lists)
        (well_formed, 1),
        ([[1, 0.1], [1, 1]], 1),
        (np.array([[1, 0.1], [1, 1]]), 1),
        # feature_types containing a string
        (well_formed, ["a"]),
        (well_formed, [1, "a"]),
    ]

    for input_list, feature_types in invalid_calls:
        with pytest.raises(TypeError):
            g.get_node_feature(input_list, feature_types)


def test_graphdata_getsampledneighbors():
    """
    Test sampled neighbors

    Samples two hops of neighbors (2 of type 2, then 3 of type 1) with both the
    RANDOM and the EDGE_WEIGHT strategy and checks the result shape.
    """
    logger.info('test get sampled neighbors.\n')
    g = ds.GraphData(DATASET_FILE, 1)
    edges = g.get_all_edges(0)
    nodes = g.get_nodes_from_edges(edges)
    assert len(nodes) == 40

    # Distinct source nodes of the first 21 edges; shared by both strategies.
    start_nodes = np.unique(nodes[0:21, 0])

    neighbor = g.get_sampled_neighbors(
        start_nodes, [2, 3], [2, 1], SamplingStrategy.RANDOM)
    assert neighbor.shape == (10, 9)

    neighbor = g.get_sampled_neighbors(
        start_nodes, [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
    assert neighbor.shape == (10, 9)


def test_graphdata_getnegsampledneighbors():
    """
    Test neg sampled neighbors

    Requests 5 negative neighbors of type 2 for each type-1 node and checks
    the result shape.
    """
    logger.info('test get negative sampled neighbors.\n')
    g = ds.GraphData(DATASET_FILE, 2)
    nodes = g.get_all_nodes(1)
    assert len(nodes) == 10
    neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
    # presumably 1 column for the node itself + 5 negative samples - TODO confirm
    assert neighbor.shape == (10, 6)


def test_graphdata_graphinfo():
    """
    Test graph info

    Checks the node/edge types, counts and feature types reported by graph_info().
    """
    logger.info('test graph info.\n')
    g = ds.GraphData(DATASET_FILE, 2)
    graph_info = g.graph_info()
    assert graph_info['node_type'] == [1, 2]
    assert graph_info['edge_type'] == [0]
    assert graph_info['node_num'] == {1: 10, 2: 10}
    assert graph_info['edge_num'] == {0: 40}
    assert graph_info['node_feature_type'] == [1, 2, 3, 4]
    assert graph_info['edge_feature_type'] == [1, 2]


class RandomBatchedSampler(ds.Sampler):
    """
    Sampler yielding batches of ids drawn without replacement from 1..index_range.

    Args:
        index_range: number of ids to draw from; the ids are 1..index_range inclusive.
        num_edges_per_sample: number of ids in each yielded batch.
    """

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        """Shuffle the ids and yield them in consecutive full-size batches."""
        # Ids are 1-based.
        indices = list(range(1, self.index_range + 1))
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(indices)
        for start in range(0, self.index_range, self.num_edges_per_sample):
            stop = start + self.num_edges_per_sample
            # Drop remainder: a trailing partial batch is never yielded.
            if stop <= self.index_range:
                yield indices[start:stop]


class GNNGraphDataset:
    """
    Random-access source for GeneratorDataset built on top of a GraphData handle.

    Args:
        g: graph object providing graph_info(), get_nodes_from_edges(),
            get_neg_sampled_neighbors(), get_sampled_neighbors() and get_node_feature().
        batch_num: number of edges per sample; only used to compute the length.
    """

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        """Return the number of samples: edge count of edge type 0 // batch_num."""
        # Total sample size of GNN dataset
        # In this case, the size should be total_num_edges/num_edges_per_sample
        return self.g.graph_info()['edge_num'][0] // self.batch_num

    def __getitem__(self, index):
        """
        Build one sample from a batch of edge ids.

        Args:
            index: array of edge ids (must support ``astype``), as yielded by
                RandomBatchedSampler.

        Returns:
            Tuple of (sampled neighbors of the source nodes, sampled neighbors of
            the negative nodes, first requested feature of the former, second
            requested feature of the latter).
        """
        # index will be a list of indices yielded from RandomBatchedSampler
        # Fetch edges/nodes/samples/features based on indices
        nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
        # Keep only the first column of each returned node pair.
        nodes = nodes[:, 0]
        neg_nodes = self.g.get_neg_sampled_neighbors(
            node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])
        # Column 0 is skipped and the remaining columns are flattened into one
        # node list (presumably column 0 is the source node itself - TODO confirm).
        neg_nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
        nodes_neighbors_features = self.g.get_node_feature(
            node_list=nodes_neighbors, feature_types=[2, 3])
        neg_neighbors_features = self.g.get_node_feature(
            node_list=neg_nodes_neighbors, feature_types=[2, 3])
        return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]


def test_graphdata_generatordataset():
    """
    Test generator dataset

    Drives GNNGraphDataset through GeneratorDataset with RandomBatchedSampler
    and checks the shape of every column as well as the total row count.
    """
    logger.info('test generator dataset.\n')

    # Reduce memory required by disabling the shm optimization.
    mem_original = ds.config.get_enable_shared_mem()
    ds.config.set_enable_shared_mem(False)
    try:
        g = ds.GraphData(DATASET_FILE)
        batch_num = 2
        edge_num = g.graph_info()['edge_num'][0]
        out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
        dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                      sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
        dataset = dataset.repeat(2)
        itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
        num_rows = 0
        for data in itr:
            assert data['neighbors'].shape == (2, 7)
            assert data['neg_neighbors'].shape == (6, 7)
            assert data['neighbors_features'].shape == (2, 7)
            assert data['neg_neighbors_features'].shape == (6, 7)
            num_rows += 1
        assert num_rows == 40
    finally:
        # Restore the global setting even when an assertion above fails, so the
        # disabled shared-memory mode does not leak into other tests.
        ds.config.set_enable_shared_mem(mem_original)


def test_graphdata_randomwalkdefault():
    """
    Test random walk defaults

    Walks a 39-step meta path of node type 1 from every type-1 node, leaving
    the remaining random_walk arguments at their defaults.
    """
    logger.info('test randomwalk with default parameters.\n')
    g = ds.GraphData(SOCIAL_DATA_FILE, 1)
    nodes = g.get_all_nodes(1)
    assert len(nodes) == 33

    meta_path = [1] * 39
    walks = g.random_walk(nodes, meta_path)
    assert walks.shape == (33, 40)


def test_graphdata_randomwalk():
    """
    Test random walk

    Same walk as the default test, but passing 2.0, 0.5 and -1 explicitly as
    the three positional arguments after meta_path.
    """
    logger.info('test random walk with given parameters.\n')
    g = ds.GraphData(SOCIAL_DATA_FILE, 1)
    nodes = g.get_all_nodes(1)
    assert len(nodes) == 33

    meta_path = [1] * 39
    walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
    assert walks.shape == (33, 40)


def test_graphdata_getedgefeature():
    """
    Test get edge feature

    Fetches edge features 1 and 2 for every edge of type 0 and checks their shapes.
    """
    logger.info('test get_edge_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    edges = g.get_all_edges(0)
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].shape == (40,)
    assert features[1].shape == (40,)


def test_graphdata_getedgefeature_invalidcase():
    """
    Test get edge feature with invalid edge id, 0 should be returned for those invalid edge id in correct index
    """
    logger.info('test get_edge_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    edges = g.get_all_edges(0)
    # Overwrite one entry with an id that is not a valid edge.
    edges[-6] = -1
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].shape == (40,)
    assert features[1].shape == (40,)
    assert features[0][-6] == 0
    assert features[1][-6] == 0.


def test_graphdata_getnodefeature_invalidcase():
    """
    Test get node feature with invalid node id, 0 should be returned for those invalid node id in correct index
    """
    logger.info('test get_node_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    nodes = g.get_all_nodes(node_type=1)
    # Overwrite one entry with an id that is not a valid node.
    nodes[5] = -1
    features = g.get_node_feature(node_list=nodes, feature_types=[2, 3])
    assert features[0].shape == (10,)
    assert features[1].shape == (10,)
    assert features[0][5] == 0.
    assert features[1][5] == 0


def test_graphdata_getedgesfromnodes():
    """
    Test get edges from nodes

    Looks up the edge id for each node pair and compares against the expected ids.
    """
    logger.info('test get_edges_from_nodes\n')
    g = ds.GraphData(DATASET_FILE)

    nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
    edges = g.get_edges_from_nodes(node_list=nodes_pair_list)
    assert edges.tolist() == [1, 9, 31, 17, 20, 40]


if __name__ == '__main__':
    # Run every test in this module when executed directly as a script.
    test_graphdata_getfullneighbor()
    test_graphdata_getallneighbors_special_format()
    test_graphdata_getnodefeature_input_check()
    test_graphdata_getsampledneighbors()
    test_graphdata_getnegsampledneighbors()
    test_graphdata_graphinfo()
    test_graphdata_generatordataset()
    test_graphdata_randomwalkdefault()
    test_graphdata_randomwalk()
    test_graphdata_getedgefeature()
    test_graphdata_getedgesfromnodes()
    test_graphdata_getnodefeature_invalidcase()
    test_graphdata_getedgefeature_invalidcase()