# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph Attention Networks."""
import mindspore.nn as nn
from mindspore._checkparam import Validator

from aggregator import AttentionAggregator


class GAT(nn.Cell):
    """
    Graph Attention Network.

    Args:
        ftr_dims (int): Dimension of the input node features.
        num_class (int): Number of classes to identify.
        num_nodes (int): Number of nodes in the graph.
        hidden_units (list[int]): Number of hidden units in each layer.
        num_heads (list[int]): Number of attention heads in each layer.
        attn_drop (float): Dropout ratio of the attention coefficients.
            Default: 0.0.
        ftr_drop (float): Dropout ratio of the input features. Default: 0.0.
        activation (Cell): Activation function for the hidden layers.
            Default: nn.ELU().
        residual (bool): Whether to use residual connections between
            intermediate layers. Default: False.

    Examples:
        >>> ft_sizes = 1433
        >>> num_class = 7
        >>> num_nodes = 2708
        >>> hid_units = [8]
        >>> n_heads = [8, 1]
        >>> activation = nn.ELU()
        >>> residual = False
        >>> input_data = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
        >>> biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))
        >>> net = GAT(ft_sizes,
        ...           num_class,
        ...           num_nodes,
        ...           hidden_units=hid_units,
        ...           num_heads=n_heads,
        ...           attn_drop=0.6,
        ...           ftr_drop=0.6,
        ...           activation=activation,
        ...           residual=residual)
        >>> output = net(input_data, biases)
    """

    def __init__(self,
                 ftr_dims,
                 num_class,
                 num_nodes,
                 hidden_units,
                 num_heads,
                 attn_drop=0.0,
                 ftr_drop=0.0,
                 activation=nn.ELU(),
                 residual=False):
        super(GAT, self).__init__()
        self.ftr_dims = Validator.check_positive_int(ftr_dims)
        self.num_class = Validator.check_positive_int(num_class)
        self.num_nodes = Validator.check_positive_int(num_nodes)
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.ftr_drop = ftr_drop
        self.activation = activation
        self.residual = Validator.check_bool(residual)
        self.layers = []
        # first layer: project raw features to hidden_units[0] per head
        self.layers.append(AttentionAggregator(
            self.ftr_dims,
            self.hidden_units[0],
            self.num_heads[0],
            self.ftr_drop,
            self.attn_drop,
            self.activation,
            residual=False))
        # intermediate layers: input dim is the concatenation of the previous
        # layer's heads
        for i in range(1, len(self.hidden_units)):
            self.layers.append(AttentionAggregator(
                self.hidden_units[i-1]*self.num_heads[i-1],
                self.hidden_units[i],
                self.num_heads[i],
                self.ftr_drop,
                self.attn_drop,
                self.activation,
                residual=self.residual))
        # output layer: no activation, heads are summed instead of concatenated
        self.layers.append(AttentionAggregator(
            self.hidden_units[-1]*self.num_heads[-2],
            self.num_class,
            self.num_heads[-1],
            self.ftr_drop,
            self.attn_drop,
            activation=None,
            residual=False,
            output_transform='sum'))
        self.layers = nn.layer.CellList(self.layers)

    def construct(self, input_data, bias_mat):
        for cell in self.layers:
            input_data = cell(input_data, bias_mat)
        # The output aggregator sums its num_heads[-1] head outputs, so divide
        # here to return the average over the output heads.
        return input_data / self.num_heads[-1]
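

if __name__ == '__main__':
    # Minimal smoke-test sketch mirroring the docstring example above. The
    # random feature/bias tensors and the Cora-like sizes (2708 nodes, 1433
    # input features, 7 classes, hidden_units=[8], num_heads=[8, 1]) are
    # illustrative assumptions, not required values.
    import numpy as np
    from mindspore import Tensor

    features = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))
    net = GAT(1433,
              7,
              2708,
              hidden_units=[8],
              num_heads=[8, 1],
              attn_drop=0.6,
              ftr_drop=0.6,
              activation=nn.ELU(),
              residual=False)
    logits = net(features, biases)
    print(logits.shape)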