# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Graph Attention Networks."""
16import mindspore.nn as nn
17from mindspore._checkparam import Validator
18
19from aggregator import AttentionAggregator
20
21
class GAT(nn.Cell):
    """
    Graph Attention Network.

    Args:
        ftr_dims (int): Dimension of the initial node features.
        num_class (int): Number of classes to identify.
        num_nodes (int): Number of nodes in the graph.
        hidden_units (list[int]): Number of hidden units at each layer.
        num_heads (list[int]): Number of attention heads at each layer,
            including the output layer.
        attn_drop (float): Dropout ratio of the attention coefficients,
            default 0.0.
        ftr_drop (float): Dropout ratio of the input features, default 0.0.
        activation (Cell): Activation function applied to the hidden layers
            (the output layer uses none), default nn.ELU().
        residual (bool): Whether to use residual connections between
            intermediate layers, default False.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> ft_sizes = 1433
        >>> num_class = 7
        >>> num_nodes = 2708
        >>> hid_units = [8]
        >>> n_heads = [8, 1]
        >>> activation = nn.ELU()
        >>> residual = False
        >>> input_data = Tensor(
        ...     np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = GAT(ft_sizes,
        ...           num_class,
        ...           num_nodes,
        ...           hidden_units=hid_units,
        ...           num_heads=n_heads,
        ...           attn_drop=0.6,
        ...           ftr_drop=0.6,
        ...           activation=activation,
        ...           residual=residual)
        >>> output = net(input_data, biases)
    """

    def __init__(self,
                 ftr_dims,
                 num_class,
                 num_nodes,
                 hidden_units,
                 num_heads,
                 attn_drop=0.0,
                 ftr_drop=0.0,
                 activation=nn.ELU(),
                 residual=False):
        super(GAT, self).__init__()
        self.ftr_dims = Validator.check_positive_int(ftr_dims)
        self.num_class = Validator.check_positive_int(num_class)
        self.num_nodes = Validator.check_positive_int(num_nodes)
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.ftr_drop = ftr_drop
        self.activation = activation
        self.residual = Validator.check_bool(residual)
        self.layers = []
        # first layer
        self.layers.append(AttentionAggregator(
            self.ftr_dims,
            self.hidden_units[0],
            self.num_heads[0],
            self.ftr_drop,
            self.attn_drop,
            self.activation,
            residual=False))
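        # the input size of each subsequent layer assumes the previous
        # aggregator concatenates its attention head outputs, i.e.
        # hidden_units[i-1] * num_heads[i-1] features per node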
        # intermediate layer
        for i in range(1, len(self.hidden_units)):
            self.layers.append(AttentionAggregator(
                self.hidden_units[i-1]*self.num_heads[i-1],
                self.hidden_units[i],
                self.num_heads[i],
                self.ftr_drop,
                self.attn_drop,
                self.activation,
                residual=self.residual))
        # output layer
        self.layers.append(AttentionAggregator(
            self.hidden_units[-1]*self.num_heads[-2],
            self.num_class,
            self.num_heads[-1],
            self.ftr_drop,
            self.attn_drop,
            activation=None,
            residual=False,
            output_transform='sum'))
        self.layers = nn.layer.CellList(self.layers)

    def construct(self, input_data, bias_mat):
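        """
        Forward computation.

        Args:
            input_data (Tensor): Node features of shape (batch, num_nodes, ftr_dims).
            bias_mat (Tensor): Bias matrix derived from the graph adjacency,
                of shape (batch, num_nodes, num_nodes).

        Returns:
            Tensor, per-node class scores of shape (batch, num_nodes, num_class).
        """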
        for cell in self.layers:
            input_data = cell(input_data, bias_mat)
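        # the output layer sums the scores of its heads (output_transform='sum'),
        # so divide by the number of output heads to average them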
        return input_data/self.num_heads[-1]