• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import random
16import pytest
17import numpy as np
18import mindspore.dataset as ds
19from mindspore import log as logger
20from mindspore.dataset.engine import SamplingStrategy
21from mindspore.dataset.engine import OutputFormat
22
23DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
24SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
25
26
def test_graphdata_getfullneighbor():
    """
    Test that get_all_neighbors returns the neighbor matrix for every type-1 node
    and that this matrix can be passed straight to get_node_feature.
    """
    logger.info('test get all neighbors.\n')
    graph = ds.GraphData(DATASET_FILE, 2)
    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    neighbor_matrix = graph.get_all_neighbors(type1_nodes, 2)
    assert neighbor_matrix.shape == (10, 6)

    # Features are requested for the whole 2-D neighbor matrix at once; the
    # first returned feature array keeps the matrix's shape.
    feature_arrays = graph.get_node_feature(neighbor_matrix.tolist(), [2, 3])
    assert feature_arrays[0].shape == (10, 6)
39
40
def test_graphdata_getallneighbors_special_format():
    """
    Test get_all_neighbors with the COO and CSR output formats on type-1 nodes.
    """
    logger.info('test get all neighbors with special format.\n')
    graph = ds.GraphData(DATASET_FILE, 2)
    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    # COO: a single 2-column array.
    coo_pairs = graph.get_all_neighbors(type1_nodes, 2, OutputFormat.COO)
    assert coo_pairs.shape == (20, 2)

    # CSR: one offset per input node plus a flat neighbor array.
    offsets, flat_neighbors = graph.get_all_neighbors(type1_nodes, 2, OutputFormat.CSR)
    assert offsets.shape == (10,)
    assert flat_neighbors.shape == (20,)
56
57
def test_graphdata_getnodefeature_input_check():
    """
    Test that get_node_feature rejects malformed arguments with TypeError.

    Covered inputs: ragged or mixed-depth node lists, non-integer node ids
    (as a nested list and as a NumPy array), a feature_types argument that is
    not a list, and feature_types lists containing non-integers.
    """
    logger.info('test getnodefeature input check.\n')
    g = ds.GraphData(DATASET_FILE)

    # Each entry is (node_list, feature_types); every call must raise TypeError.
    # NOTE(review): the float-id cases also pass feature_types=1 (not a list),
    # so the TypeError may come from either argument; consider using [1] there
    # if the intent is to isolate the node-id check.
    invalid_cases = [
        # Mixed-depth / ragged node lists.
        ([1, [1, 1]], [1]),
        ([[1, 1], 1], [1]),
        ([[1, 1], [1, 1, 1]], [1]),
        ([[1, 1, 1], [1, 1]], [1]),
        ([[1, 1], [1, [1, 1]]], [1]),
        ([[1, 1], [[1, 1], 1]], [1]),
        # feature_types that is not a list.
        ([[1, 1], [1, 1]], 1),
        # Non-integer node ids.
        ([[1, 0.1], [1, 1]], 1),
        (np.array([[1, 0.1], [1, 1]]), 1),
        # Non-integer feature types.
        ([[1, 1], [1, 1]], ["a"]),
        ([[1, 1], [1, 1]], [1, "a"]),
    ]

    for input_list, feature_types in invalid_cases:
        # Inputs are built before entering the raises-context, so only the API
        # call under test can satisfy the expected TypeError.
        with pytest.raises(TypeError):
            g.get_node_feature(input_list, feature_types)
107
108
def test_graphdata_getsampledneighbors():
    """
    Test get_sampled_neighbors under both the RANDOM and EDGE_WEIGHT strategies.
    """
    logger.info('test get sampled neighbors.\n')
    g = ds.GraphData(DATASET_FILE, 1)
    edge_ids = g.get_all_edges(0)
    endpoint_pairs = g.get_nodes_from_edges(edge_ids)
    assert len(endpoint_pairs) == 40

    # Distinct ids from column 0 of the first 21 endpoint pairs.
    start_nodes = np.unique(endpoint_pairs[0:21, 0])
    for strategy in (SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT):
        sampled = g.get_sampled_neighbors(start_nodes, [2, 3], [2, 1], strategy)
        assert sampled.shape == (10, 9)
124
125
def test_graphdata_getnegsampledneighbors():
    """
    Test get_neg_sampled_neighbors on every type-1 node.
    """
    logger.info('test get negative sampled neighbors.\n')
    graph = ds.GraphData(DATASET_FILE, 2)
    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    negative_samples = graph.get_neg_sampled_neighbors(type1_nodes, 5, 2)
    assert negative_samples.shape == (10, 6)
136
137
def test_graphdata_graphinfo():
    """
    Test that graph_info reports the expected node/edge types, per-type counts
    and feature types for the test graph.
    """
    logger.info('test graph info.\n')
    g = ds.GraphData(DATASET_FILE, 2)
    info = g.graph_info()

    expected_info = [
        ('node_type', [1, 2]),
        ('edge_type', [0]),
        ('node_num', {1: 10, 2: 10}),
        ('edge_num', {0: 40}),
        ('node_feature_type', [1, 2, 3, 4]),
        ('edge_feature_type', [1, 2]),
    ]
    for field, wanted in expected_info:
        assert info[field] == wanted, field
151
152
class RandomBatchedSampler(ds.Sampler):
    """Sampler yielding shuffled batches of 1-based indices without replacement.

    On each iteration, the indices 1..index_range are shuffled and then cut into
    consecutive chunks of ``num_edges_per_sample``. A trailing chunk shorter
    than that is dropped.
    """

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        shuffled = list(range(1, self.index_range + 1))
        # Seed the RNG here (e.g. random.seed(0)) if a reproducible order is needed.
        random.shuffle(shuffled)
        step = self.num_edges_per_sample
        for start in range(0, self.index_range, step):
            stop = start + step
            if stop > self.index_range:
                # Drop the incomplete remainder batch.
                continue
            yield shuffled[start:stop]
169
170
class GNNGraphDataset:
    """Random-access source that turns one batch of edge ids into GNN sample arrays.

    Used with ``ds.GeneratorDataset`` and ``RandomBatchedSampler``. Each
    ``__getitem__`` call receives one batch of edge ids and returns four arrays:
    sampled neighbors of the batch's nodes, sampled neighbors of their negative
    samples, and one feature array for each of those two neighbor sets.
    """

    def __init__(self, g, batch_num):
        # g: the graph object queried for nodes/neighbors/features.
        # batch_num: number of edge ids per sample.
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Dataset size = number of type-0 edges // edges per sample.
        edge_counts = self.g.graph_info()['edge_num']
        return edge_counts[0] // self.batch_num

    def __getitem__(self, index):
        graph = self.g
        # ``index`` is one batch yielded by the sampler; it is expected to be a
        # NumPy array here (``.astype`` is called on it). Ids are cast to int32.
        batch_nodes = graph.get_nodes_from_edges(index.astype(np.int32))[:, 0]

        negatives = graph.get_neg_sampled_neighbors(
            node_list=batch_nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        batch_neighbors = graph.get_sampled_neighbors(
            node_list=batch_nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])
        # Column 0 of ``negatives`` is skipped; the remaining ids are flattened
        # into a 1-D node list before sampling their neighbors.
        negative_neighbors = graph.get_sampled_neighbors(
            node_list=negatives[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])

        batch_features = graph.get_node_feature(
            node_list=batch_neighbors, feature_types=[2, 3])
        negative_features = graph.get_node_feature(
            node_list=negative_neighbors, feature_types=[2, 3])
        # NOTE(review): the positive side returns feature index 0 (presumably
        # feature type 2), the negative side feature index 1 (presumably type 3).
        # Shapes match either way; confirm this asymmetry is intended.
        return batch_neighbors, negative_neighbors, batch_features[0], negative_features[1]
197
198
def test_graphdata_generatordataset():
    """
    Test feeding graph samples through a GeneratorDataset.

    GNNGraphDataset is driven by RandomBatchedSampler over all type-0 edges.
    The pipeline is repeated, and every produced row is shape-checked. The
    global shared-memory setting is disabled for the duration of the test and
    is always restored afterwards, even when an assertion fails.
    """
    logger.info('test generator dataset.\n')

    # Reduce the memory required by disabling the shared-memory optimization.
    # The original value is restored in ``finally`` so a failure here cannot
    # leak the global config change into later tests.
    mem_original = ds.config.get_enable_shared_mem()
    ds.config.set_enable_shared_mem(False)
    try:
        g = ds.GraphData(DATASET_FILE)
        batch_num = 2
        repeat_count = 2
        edge_num = g.graph_info()['edge_num'][0]
        out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
        dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                      sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
        dataset = dataset.repeat(repeat_count)
        itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
        i = 0
        for data in itr:
            assert data['neighbors'].shape == (2, 7)
            assert data['neg_neighbors'].shape == (6, 7)
            assert data['neighbors_features'].shape == (2, 7)
            assert data['neg_neighbors_features'].shape == (6, 7)
            i += 1
        # The sampler drops incomplete batches, so each pass yields
        # edge_num // batch_num rows (40 // 2 * 2 = 40 for this dataset).
        assert i == (edge_num // batch_num) * repeat_count
    finally:
        ds.config.set_enable_shared_mem(mem_original)
227
def test_graphdata_randomwalkdefault():
    """
    Test random_walk called with only the start nodes and the meta path.
    """
    logger.info('test randomwalk with default parameters.\n')
    g = ds.GraphData(SOCIAL_DATA_FILE, 1)
    start_nodes = g.get_all_nodes(1)
    assert len(start_nodes) == 33

    path_len = 39
    meta_path = [1] * path_len
    walks = g.random_walk(start_nodes, meta_path)
    # One row per start node, with path_len + 1 entries per row.
    assert walks.shape == (33, path_len + 1)
240
241
def test_graphdata_randomwalk():
    """
    Test random_walk with explicitly supplied walk parameters.
    """
    logger.info('test random walk with given parameters.\n')
    g = ds.GraphData(SOCIAL_DATA_FILE, 1)
    start_nodes = g.get_all_nodes(1)
    assert len(start_nodes) == 33

    path_len = 39
    meta_path = [1] * path_len
    # The trailing positionals 2.0, 0.5, -1 are presumably step_home_param,
    # step_away_param and default_node; confirm against GraphData.random_walk.
    walks = g.random_walk(start_nodes, meta_path, 2.0, 0.5, -1)
    assert walks.shape == (33, path_len + 1)
254
255
def test_graphdata_getedgefeature():
    """
    Test get_edge_feature for edge feature types 1 and 2 over all type-0 edges.
    """
    logger.info('test get_edge_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    all_edges = g.get_all_edges(0)
    edge_features = g.get_edge_feature(all_edges, [1, 2])
    # One feature array per requested type, each with one value per edge.
    for pos in (0, 1):
        assert edge_features[pos].shape == (40,)
266
267
def test_graphdata_getedgefeature_invalidcase():
    """
    Test get_edge_feature when the input contains an invalid edge id: the output
    keeps its full length and the invalid id's position holds 0.
    """
    logger.info('test get_edge_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    edge_ids = g.get_all_edges(0)
    invalid_pos = -6
    edge_ids[invalid_pos] = -1  # overwrite one id with the invalid id -1
    edge_features = g.get_edge_feature(edge_ids, [1, 2])
    for pos in (0, 1):
        assert edge_features[pos].shape == (40,)
    for pos in (0, 1):
        assert edge_features[pos][invalid_pos] == 0
281
282
def test_graphdata_getnodefeature_invalidcase():
    """
    Test get_node_feature when the input contains an invalid node id: the output
    keeps its full length and the invalid id's position holds 0.
    """
    logger.info('test get_node_feature.\n')
    g = ds.GraphData(DATASET_FILE)
    node_ids = g.get_all_nodes(node_type=1)
    invalid_pos = 5
    node_ids[invalid_pos] = -1  # overwrite one id with the invalid id -1
    node_features = g.get_node_feature(node_list=node_ids, feature_types=[2, 3])
    for pos in (0, 1):
        assert node_features[pos].shape == (10,)
    for pos in (0, 1):
        assert node_features[pos][invalid_pos] == 0
296
297
def test_graphdata_getedgesfromnodes():
    """
    Test that get_edges_from_nodes maps each node pair to the expected edge id.
    """
    logger.info('test get_edges_from_nodes\n')
    g = ds.GraphData(DATASET_FILE)

    # (node pair, expected edge id); pair order matters, e.g. (110, 210) vs (210, 110).
    pair_to_edge = [
        ((101, 201), 1),
        ((103, 207), 9),
        ((204, 105), 31),
        ((108, 208), 17),
        ((110, 210), 20),
        ((210, 110), 40),
    ]
    node_pairs = [pair for pair, _ in pair_to_edge]
    expected_edges = [edge for _, edge in pair_to_edge]

    edges = g.get_edges_from_nodes(node_list=node_pairs)
    assert edges.tolist() == expected_edges
308
309
if __name__ == '__main__':
    # Allow running this module directly as a script; every test defined in
    # the module is listed here so none is silently skipped.
    test_graphdata_getfullneighbor()
    test_graphdata_getallneighbors_special_format()
    test_graphdata_getnodefeature_input_check()
    test_graphdata_getsampledneighbors()
    test_graphdata_getnegsampledneighbors()
    test_graphdata_graphinfo()
    test_graphdata_generatordataset()
    test_graphdata_randomwalkdefault()
    test_graphdata_randomwalk()
    test_graphdata_getedgefeature()
    test_graphdata_getedgesfromnodes()
    test_graphdata_getnodefeature_invalidcase()
    test_graphdata_getedgefeature_invalidcase()
323