# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test GraphData in distributed (client/server) mode.

A GraphData server is started in a child process; the test then connects a
client to it and checks node/edge/feature lookups, neighbor sampling, and
feeding sampled graph data through a GeneratorDataset pipeline.
"""

import os
import random
import time
from multiprocessing import Process
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
from mindspore.dataset.engine import OutputFormat

# Relative path to the pre-built test graph; the assertions below encode
# this dataset's exact contents (10 nodes of type 1, 40 edges of type 0, ...).
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"


def graphdata_startserver(server_port):
    """
    Start a GraphData server on `server_port` (runs in a child process).

    Constructing GraphData in 'server' mode hosts DATASET_FILE for remote
    clients; this call does not return a usable handle to the caller.
    """
    logger.info('test start server.\n')
    ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)


class RandomBatchedSampler(ds.Sampler):
    # RandomBatchedSampler generate random sequence without replacement in a batched manner
    def __init__(self, index_range, num_edges_per_sample):
        """
        Args:
            index_range: total number of edges to sample from; edge ids are
                1-based, i.e. drawn from [1, index_range].
            num_edges_per_sample: how many edge ids each yielded batch holds.
        """
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        """Yield shuffled, non-overlapping batches of edge ids (remainder dropped)."""
        # Edge ids in this dataset start at 1, hence the i+1 shift.
        indices = [i+1 for i in range(self.index_range)]
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(indices)
        for i in range(0, self.index_range, self.num_edges_per_sample):
            # Drop reminder
            if i + self.num_edges_per_sample <= self.index_range:
                yield indices[i: i + self.num_edges_per_sample]


class GNNGraphDataset():
    """Maps batches of edge ids (from RandomBatchedSampler) to GNN training samples."""

    def __init__(self, g, batch_num):
        # g: a GraphData client handle; batch_num: edges consumed per sample.
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Total sample size of GNN dataset
        # In this case, the size should be total_num_edges/num_edges_per_sample
        return self.g.graph_info()['edge_num'][0] // self.batch_num

    def __getitem__(self, index):
        # index will be a list of indices yielded from RandomBatchedSampler
        # Fetch edges/nodes/samples/features based on indices
        nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
        # Keep only the source node of each edge.
        nodes = nodes[:, 0]
        neg_nodes = self.g.get_neg_sampled_neighbors(
            node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        # Two-hop neighbor sampling: 2 neighbors of type 2, then 2 of type 1.
        nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
            2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
        # neg_nodes column 0 is the seed node itself; sample neighbors only
        # for the 3 negative nodes per seed ([:, 1:]).
        neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
                                                           neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
        nodes_neighbors_features = self.g.get_node_feature(
            node_list=nodes_neighbors, feature_types=[2, 3])
        neg_neighbors_features = self.g.get_node_feature(
            node_list=neg_nodes_neighbors, feature_types=[2, 3])
        # NOTE(review): positive samples return feature type 2 ([0]) but
        # negative samples return feature type 3 ([1]) — confirm the
        # asymmetry is intentional.
        return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]


def test_graphdata_distributed():
    """
    Test distributed

    Spawns a GraphData server in a child process, connects a client, and
    verifies node/edge/feature lookups, all-neighbor output formats, edge
    lookup from node pairs, and a GeneratorDataset pipeline driven by
    RandomBatchedSampler + GNNGraphDataset.
    """
    # Skipped under AddressSanitizer builds.
    ASAN = os.environ.get('ASAN_OPTIONS')
    if ASAN:
        logger.info("skip the graphdata distributed when asan mode")
        return

    logger.info('test distributed.\n')

    # Random port to avoid collisions with concurrently running tests.
    server_port = random.randint(10000, 60000)

    # NOTE(review): p1 is never joined or terminated — the server process is
    # left to die with the test process. Confirm this teardown is intended.
    p1 = Process(target=graphdata_startserver, args=(server_port,))
    p1.start()
    # Give the server time to come up before connecting the client.
    time.sleep(5)

    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)

    # --- node and node-feature lookups (values fixed by DATASET_FILE) ---
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                                      [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                                      [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    # --- all-neighbors in each supported output layout ---
    neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
    assert neighbor_normal.shape == (10, 6)
    neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
    assert neighbor_coo.shape == (20, 2)
    offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
    assert offset_table.shape == (10,)
    assert neighbor_csr.shape == (20,)

    # --- edge and edge-feature lookups ---
    edges = g.get_all_edges(0)
    assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
                                    0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]

    # --- edge ids resolved from (src, dst) node pairs ---
    nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
    edges = g.get_edges_from_nodes(nodes_pair_list)
    assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]

    # --- full sampling pipeline: 40 edges / batch of 2 = 20 samples,
    #     repeated twice -> 40 iterations ---
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)
    itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    i = 0
    for data in itr:
        assert data['neighbors'].shape == (2, 7)
        assert data['neg_neighbors'].shape == (6, 7)
        assert data['neighbors_features'].shape == (2, 7)
        assert data['neg_neighbors_features'].shape == (6, 7)
        i += 1
    assert i == 40


if __name__ == '__main__':
    test_graphdata_distributed()