• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import os
17import random
18import time
19from multiprocessing import Process
20import numpy as np
21import mindspore.dataset as ds
22from mindspore import log as logger
23from mindspore.dataset.engine import SamplingStrategy
24from mindspore.dataset.engine import OutputFormat
25
26DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
27
28
def graphdata_startserver(server_port):
    """
    Start a GraphData server on the given port.

    Intended to run in a child process: ds.GraphData in 'server' mode
    serves the dataset at DATASET_FILE to connecting clients.
    """
    logger.info('test start server.\n')
    server = ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)
    del server
35
36
class RandomBatchedSampler(ds.Sampler):
    """Yield fixed-size batches of ids drawn in random order without replacement."""

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        # Ids are 1-based: 1 .. index_range inclusive.
        ids = list(range(1, self.index_range + 1))
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(ids)
        step = self.num_edges_per_sample
        for start in range(0, self.index_range, step):
            stop = start + step
            # Drop the trailing remainder batch if it would be short.
            if stop <= self.index_range:
                yield ids[start:stop]
53
54
class GNNGraphDataset():
    """Generator-dataset source that builds GNN samples from a GraphData handle."""

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Total sample size of GNN dataset:
        # in this case, total_num_edges / num_edges_per_sample.
        return self.g.graph_info()['edge_num'][0] // self.batch_num

    def __getitem__(self, index):
        # `index` is a list of edge ids yielded from RandomBatchedSampler.
        graph = self.g
        # Source nodes of the sampled edges (first column of the node pairs).
        src_nodes = graph.get_nodes_from_edges(index.astype(np.int32))[:, 0]
        neg_nodes = graph.get_neg_sampled_neighbors(
            node_list=src_nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        pos_neighbors = graph.get_sampled_neighbors(
            node_list=src_nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1],
            strategy=SamplingStrategy.RANDOM)
        # Negative neighbors: skip column 0 (the anchor node) before sampling.
        neg_neighbors = graph.get_sampled_neighbors(
            node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
            neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
        pos_features = graph.get_node_feature(
            node_list=pos_neighbors, feature_types=[2, 3])
        neg_features = graph.get_node_feature(
            node_list=neg_neighbors, feature_types=[2, 3])
        return pos_neighbors, neg_neighbors, pos_features[0], neg_features[1]
81
82
def test_graphdata_distributed():
    """
    Test distributed

    End-to-end check of GraphData client/server mode: start a server in a
    child process, connect as a client, and verify node/edge/neighbor/feature
    queries plus a GeneratorDataset pipeline. All expected values are fixed
    by the fixture dataset at DATASET_FILE.
    """
    # Skipped under AddressSanitizer builds (detected via ASAN_OPTIONS).
    ASAN = os.environ.get('ASAN_OPTIONS')
    if ASAN:
        logger.info("skip the graphdata distributed when asan mode")
        return

    logger.info('test distributed.\n')

    # Random port reduces collisions when tests run concurrently on one host.
    server_port = random.randint(10000, 60000)

    # Launch the graph server, then wait before connecting; the 5s sleep is
    # an empirical startup delay, not a readiness handshake.
    # NOTE(review): p1 is never joined or terminated — presumably the server
    # process exits on its own when the test ends; confirm no orphan remains.
    p1 = Process(target=graphdata_startserver, args=(server_port,))
    p1.start()
    time.sleep(5)

    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
    # Nodes of type 1 and their features (feature types 1, 2, 3).
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                                      [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                                      [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    # Same neighbor query in all three output layouts: dense adjacency
    # (NORMAL), coordinate pairs (COO), and offset table + flat list (CSR).
    neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
    assert neighbor_normal.shape == (10, 6)
    neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
    assert neighbor_coo.shape == (20, 2)
    offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
    assert offset_table.shape == (10,)
    assert neighbor_csr.shape == (20,)

    # Edges of type 0 and their features (feature types 1, 2).
    edges = g.get_all_edges(0)
    assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
                                    0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]

    # Reverse lookup: map (src, dst) node pairs back to edge ids.
    nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
    edges = g.get_edges_from_nodes(nodes_pair_list)
    assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]

    # Pipeline check: GeneratorDataset over the graph, sampled in batches of 2,
    # repeated twice -> (edge_num // batch_num) * 2 = 40 iterations.
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)
    itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    i = 0
    for data in itr:
        assert data['neighbors'].shape == (2, 7)
        assert data['neg_neighbors'].shape == (6, 7)
        assert data['neighbors_features'].shape == (2, 7)
        assert data['neg_neighbors_features'].shape == (6, 7)
        i += 1
    assert i == 40
144
145
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_graphdata_distributed()
148