# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Aggregator."""
16import mindspore.nn as nn
17from mindspore import Tensor, Parameter
18from mindspore._checkparam import Validator
19from mindspore._extends import cell_attr_register
20from mindspore.common.initializer import initializer
21from mindspore.nn.layer.activation import get_activation
22from mindspore.ops import functional as F
23from mindspore.ops import operations as P
24
25
class GNNFeatureTransform(nn.Cell):
    r"""
    The GNN feature transform layer for input.

    Applies a linear transformation to the input feature. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{inputs} * \text{kernel} + \text{bias},

    where :math:`\text{kernel}` is a weight matrix with the same data type as the inputs
    created by the layer, and :math:`\text{bias}` is a bias vector with the same data type
    as the inputs created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Raises:
        ValueError: If the shape of weight_init or bias_init is incorrect.

    Inputs:
        - **input_x** (Tensor) - The input feature tensor of shape :math:`(B, N, C)`,
          where :math:`B` is the batch size, :math:`N` is the number of nodes and
          :math:`C` equals `in_channels`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(B, N, M)`, where :math:`M` equals `out_channels`.

    Examples:
        >>> net = GNNFeatureTransform(3, 4)
        >>> input_x = Tensor(np.random.randint(0, 255, [2, 5, 3]), mindspore.float32)
        >>> output = net(input_x)
        >>> output.shape
        (2, 5, 4)
    """

    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(GNNFeatureTransform, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)

        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape should be [out_channels, in_channels]")

        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape should be [out_channels]")

            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        tensor_shape = F.shape(x)
        # Flatten the batch and node dimensions so a 2D matmul can be applied.
        input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
        output = self.matmul(input_feature, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        # Restore the (batch, nodes, out_channels) layout.
        output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
        return output

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        return s


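
# The sketch below is illustrative only and not part of the original network: it shows,
# with arbitrarily chosen sizes, how GNNFeatureTransform flattens the (batch, nodes)
# dimensions, applies the dense transform and restores the 3D layout. The helper name
# _example_feature_transform is hypothetical.
def _example_feature_transform():
    import numpy as np
    net = GNNFeatureTransform(in_channels=3, out_channels=4)
    x = Tensor(np.random.rand(2, 5, 3).astype(np.float32))   # (batch, nodes, in_channels)
    out = net(x)                                              # (batch, nodes, out_channels)
    return out.shape                                          # expected: (2, 5, 4)
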
class _BaseAggregator(nn.Cell):
    """
    Base Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the messages are aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> class MyAggregator(_BaseAggregator):
        >>>     def __init__(self, feature_in_dim, feature_out_dim):
        >>>         super(MyAggregator, self).__init__(feature_in_dim, feature_out_dim)
        >>>         self.reduce_sum = P.ReduceSum()
        >>>
        >>>     def construct(self, x):
        >>>         return self.reduce_sum(x, 1)
    """
138
139    def __init__(self,
140                 feature_in_dim,
141                 feature_out_dim,
142                 use_fc=True,
143                 weight_init="normal",
144                 bias_init="zeros",
145                 has_bias=True,
146                 dropout_ratio=None,
147                 activation=None):
148        super(_BaseAggregator, self).__init__()
149        self.in_dim = feature_in_dim
150        self.out_dim = feature_out_dim
151        self.use_fc = use_fc
152        if self.use_fc:
153            self.weight_init = weight_init
154            self.bias_init = bias_init
155            self.has_bias = has_bias
156            self.fc = GNNFeatureTransform(self.in_dim,
157                                          self.out_dim,
158                                          weight_init=self.weight_init,
159                                          bias_init=self.bias_init,
160                                          has_bias=self.has_bias)
161        self.dropout_ratio = dropout_ratio
162        if self.dropout_ratio is not None:
163            self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
164        self.dropout_flag = self.dropout_ratio is not None
165        self.activation = get_activation(activation)
166        self.activation_flag = self.activation is not None
167
    def construct(self, **kwargs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError


class MeanAggregator(_BaseAggregator):
    """
    Mean Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the messages are aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
        >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
        >>> output = net(input_data)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(MeanAggregator, self).__init__(
            feature_in_dim,
            feature_out_dim,
            use_fc,
            weight_init,
            bias_init,
            has_bias,
            dropout_ratio,
            activation)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

    def construct(self, input_feature):
        if self.use_fc:
            input_feature = self.fc(input_feature)
        if self.dropout_flag:
            input_feature = self.dropout(input_feature)
        if self.activation_flag:
            input_feature = self.activation(input_feature)
        output_feature = self.reduce_mean(input_feature, 1)
        return output_feature


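
# Illustrative sketch only (not part of the original model): aggregates neighbor features
# of shape (num_nodes, num_neighbors, feature_dim) into (num_nodes, feature_out_dim) by
# mean pooling over the neighbor axis. The sizes and the helper name
# _example_mean_aggregator are arbitrary assumptions.
def _example_mean_aggregator():
    import numpy as np
    net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    neighbor_feature = Tensor(np.random.rand(16, 3, 32).astype(np.float32))  # (nodes, neighbors, in_dim)
    output = net(neighbor_feature)                                           # (nodes, out_dim) = (16, 64)
    return output.shape
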
class AttentionHead(nn.Cell):
    """
    Attention Head for Graph Attention Networks.

    Args:
        in_channel (int): The number of input channels, i.e. the input feature dim.
        out_channel (int): The number of output channels, i.e. the output feature dim.
        in_drop_ratio (float): Input feature dropout ratio. Default: 0.0.
        coef_drop_ratio (float): Coefficient dropout ratio. Default: 0.0.
        residual (bool): Whether to use a residual connection. Default: False.
        coef_activation (Cell): The attention coefficient activation function.
            Default: nn.LeakyReLU().
        activation (Cell): The output activation function. Default: nn.ELU().

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).

    Examples:
        >>> head = AttentionHead(1433,
                                 8,
                                 in_drop_ratio=0.6,
                                 coef_drop_ratio=0.6,
                                 residual=False)
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> output = head(input_data, bias_mat)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 in_drop_ratio=0.0,
                 coef_drop_ratio=0.0,
                 residual=False,
                 coef_activation=nn.LeakyReLU(),
                 activation=nn.ELU()):
        super(AttentionHead, self).__init__()
        self.in_channel = Validator.check_positive_int(in_channel)
        self.out_channel = Validator.check_positive_int(out_channel)
        self.in_drop_ratio = in_drop_ratio
        self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.feature_transform = GNNFeatureTransform(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            has_bias=False)

        self.f_1_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.f_2_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.softmax = nn.Softmax()

        self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
        self.batch_matmul = P.BatchMatMul()
        self.bias_add = P.BiasAdd()
        self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
        self.residual = Validator.check_bool(residual)
        self.residual_transform_flag = False
        if self.residual:
            if in_channel != out_channel:
                self.residual_transform_flag = True
                self.residual_transform = GNNFeatureTransform(
                    in_channels=self.in_channel,
                    out_channels=self.out_channel)
            else:
                self.residual_transform = None
        self.coef_activation = coef_activation
        self.activation = activation

    def construct(self, input_feature, bias_mat):
        input_feature = self.in_drop(input_feature)

        feature = self.feature_transform(input_feature)
        # self-attention following the authors: broadcast f_1 against the transpose of f_2
        # to get pairwise attention logits, mask them with bias_mat and normalize
        f_1 = self.f_1_transform(feature)
        f_2 = self.f_2_transform(feature)
        logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
        logits = self.coef_activation(logits) + bias_mat
        coefs = self.softmax(logits)

        coefs = self.coef_drop(coefs)
        feature = self.in_drop_2(feature)

        # aggregate neighbor features weighted by the attention coefficients
        ret = self.batch_matmul(coefs, feature)
        ret = P.Squeeze(0)(ret)
        ret = self.bias_add(ret, self.bias)
        ret = P.ExpandDims()(ret, 0)
        # residual connection
        if self.residual:
            if self.residual_transform_flag:
                res = self.residual_transform(input_feature)
                ret = ret + res
            else:
                ret = ret + input_feature
        # activation
        if self.activation is not None:
            ret = self.activation(ret)
        return ret


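
# Illustrative sketch only (not from the original script): runs a single AttentionHead on a
# tiny random graph. Because the head squeezes and re-expands the batch axis, batch_size
# must be 1. The sizes and the helper name _example_attention_head are assumptions.
def _example_attention_head():
    import numpy as np
    head = AttentionHead(in_channel=8, out_channel=4, residual=False)
    features = Tensor(np.random.rand(1, 6, 8).astype(np.float32))   # (1, num_nodes, in_channel)
    # an all-zero bias_mat keeps every edge; large negative entries would mask edges out
    bias_mat = Tensor(np.zeros((1, 6, 6), dtype=np.float32))        # (1, num_nodes, num_nodes)
    output = head(features, bias_mat)                               # (1, num_nodes, out_channel)
    return output.shape
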
class AttentionAggregator(nn.Cell):
    """
    Attention aggregator for Graph Attention Networks; it runs several attention heads in
    parallel and can be regarded as one GAT layer.

    Args:
        in_channels (int): Input channel, the input feature dim of each head.
        out_channels (int): Output channel, the output feature dim of each head.
        num_heads (int): Number of attention heads for this layer. Default: 1.
        in_drop (float): Input feature dropout ratio. Default: 0.0.
        coef_drop (float): Coefficient dropout ratio. Default: 0.0.
        activation (Cell): The output activation function. Default: nn.ELU().
        residual (bool): Whether to use a residual connection. Default: False.
        output_transform (str): How the outputs of the heads are combined, either 'concat'
            or 'sum'. Default: 'concat'.

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).

    Examples:
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = AttentionAggregator(1433,
                                      8,
                                      8)
        >>> net(input_data, biases)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_heads=1,
                 in_drop=0.0,
                 coef_drop=0.0,
                 activation=nn.ELU(),
                 residual=False,
                 output_transform='concat'):
        super(AttentionAggregator, self).__init__()
        self.num_heads = num_heads
        self.attns = []
        for _ in range(num_heads):
            self.attns.append(AttentionHead(in_channels,
                                            out_channels,
                                            in_drop_ratio=in_drop,
                                            coef_drop_ratio=coef_drop,
                                            activation=activation,
                                            residual=residual))
        self.attns = nn.layer.CellList(self.attns)
        if output_transform == 'concat':
            self.out_trans = P.Concat(-1)
        elif output_transform == 'sum':
            self.out_trans = P.AddN()
        else:
            raise ValueError("output_transform must be either 'concat' or 'sum', but got {}".format(output_transform))

    def construct(self, input_data, bias_mat):
        res = ()
        for i in range(self.num_heads):
            res += (self.attns[i](input_data, bias_mat),)
        return self.out_trans(res)

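
# Minimal smoke test, illustrative only: the shapes below are arbitrary assumptions and do
# not correspond to any dataset used by the original network. Running this module directly
# builds a 2-head AttentionAggregator on a tiny random graph and prints the output shape;
# with output_transform='concat' the heads' outputs are concatenated along the last axis.
if __name__ == "__main__":
    import numpy as np
    net = AttentionAggregator(in_channels=8, out_channels=4, num_heads=2)
    node_features = Tensor(np.random.rand(1, 6, 8).astype(np.float32))  # (batch, nodes, features)
    bias_mat = Tensor(np.zeros((1, 6, 6), dtype=np.float32))            # (batch, nodes, nodes)
    out = net(node_features, bias_mat)
    print(out.shape)                                                    # expected: (1, 6, 8)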