# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aggregator."""
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore._checkparam import Validator
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class GNNFeatureTransform(nn.Cell):
    r"""
    The GNN feature transform layer for input.

    Applies a linear transformation to the input feature. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{inputs} * \text{kernel} + \text{bias},

    where :math:`\text{kernel}` is a weight matrix with the same data type as the inputs
    created by the layer, and :math:`\text{bias}` is a bias vector with the same data type
    as the inputs created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Raises:
        ValueError: If weight_init or bias_init shape is incorrect.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(*B, N, C)`,
          where :math:`*B` represents the batch size which can be multidimensional, and :math:`N` and
          :math:`C` are the size of the last two dimensions.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(*B, N, M)`.

    Examples:
        >>> net = GNNFeatureTransform(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [1, 2, 3]), mindspore.float32)
        >>> net(input)
        [[[ 2.5246444   2.2738023   0.5711005  -3.9399147 ]
          [ 1.0739875   4.0155234   0.94188046 -5.459526  ]]]
    """

    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(GNNFeatureTransform, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)

        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        tensor_shape = F.shape(x)
        # Flatten the leading (batch, node) dimensions so a 2-D MatMul can be applied.
        input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
        output = self.matmul(input_feature, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
        return output

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        return s
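
# A minimal usage sketch for GNNFeatureTransform, kept as comments so the module
# stays import-only. The shapes below are illustrative assumptions, not values
# taken from this file (assumes `import numpy as np`; `Tensor` is imported above):
#
#   node_feat = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
#   fc = GNNFeatureTransform(in_channels=1433, out_channels=8)
#   out = fc(node_feat)   # out has shape (1, 2708, 8)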


class _BaseAggregator(nn.Cell):
    """
    Base Aggregator of GNN

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> class MyAggregator(_BaseAggregator):
        >>>     def __init__(self, feature_in_dim, feature_out_dim):
        >>>         super(MyAggregator, self).__init__(feature_in_dim, feature_out_dim)
        >>>         self.reduce_sum = P.ReduceSum()
        >>>
        >>>     def construct(self, x):
        >>>         return self.reduce_sum(x, 1)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(_BaseAggregator, self).__init__()
        self.in_dim = feature_in_dim
        self.out_dim = feature_out_dim
        self.use_fc = use_fc
        if self.use_fc:
            self.weight_init = weight_init
            self.bias_init = bias_init
            self.has_bias = has_bias
            self.fc = GNNFeatureTransform(self.in_dim,
                                          self.out_dim,
                                          weight_init=self.weight_init,
                                          bias_init=self.bias_init,
                                          has_bias=self.has_bias)
        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio is not None:
            self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
        self.dropout_flag = self.dropout_ratio is not None
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

    def construct(self, **kwargs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError
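
# A minimal subclass sketch (comments only), assuming a max-pooling reduction over
# the neighbor axis (axis 1). The class name and reduction choice are illustrative
# and not part of this file:
#
#   class MaxAggregator(_BaseAggregator):
#       def __init__(self, feature_in_dim, feature_out_dim, **kwargs):
#           super(MaxAggregator, self).__init__(feature_in_dim, feature_out_dim, **kwargs)
#           self.reduce_max = P.ReduceMax(keep_dims=False)
#
#       def construct(self, input_feature):
#           if self.use_fc:
#               input_feature = self.fc(input_feature)
#           if self.dropout_flag:
#               input_feature = self.dropout(input_feature)
#           if self.activation_flag:
#               input_feature = self.activation(input_feature)
#           return self.reduce_max(input_feature, 1)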


class MeanAggregator(_BaseAggregator):
    """
    Mean Aggregator of GNN

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
        >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
        >>> output = net(input_data)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(MeanAggregator, self).__init__(
            feature_in_dim,
            feature_out_dim,
            use_fc,
            weight_init,
            bias_init,
            has_bias,
            dropout_ratio,
            activation)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

    def construct(self, input_feature):
        if self.use_fc:
            input_feature = self.fc(input_feature)
        if self.dropout_flag:
            input_feature = self.dropout(input_feature)
        if self.activation_flag:
            input_feature = self.activation(input_feature)
        # Aggregate neighbor messages by averaging over the neighbor axis.
        output_feature = self.reduce_mean(input_feature, 1)
        return output_feature
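
# Shape walk-through for MeanAggregator (a sketch with assumed sizes; `np` is
# assumed imported): with 32 nodes, 3 sampled neighbors each, and 32-dim input
# features,
#
#   agg = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
#   neigh_feat = Tensor(np.random.rand(32, 3, 32).astype(np.float32))
#   out = agg(neigh_feat)
#
# the fc layer maps (32, 3, 32) -> (32, 3, 64) and the mean over axis 1 yields an
# aggregated feature of shape (32, 64).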


class AttentionHead(nn.Cell):
    """
    Attention Head for Graph Attention Networks.

    Args:
        in_channel (int): The number of input channels, input feature dim.
        out_channel (int): The number of output channels, output feature dim.
        in_drop_ratio (float): Input feature dropout ratio, default 0.0.
        coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
        residual (bool): Whether to use residual connection, default False.
        coef_activation (Cell): The attention coefficient activation function,
            default nn.LeakyReLU().
        activation (Cell): The output activation function, default nn.ELU().

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).

    Examples:
        >>> head = AttentionHead(1433,
                                 8,
                                 in_drop_ratio=0.6,
                                 coef_drop_ratio=0.6,
                                 residual=False)
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> bias_mat = Tensor(np.zeros((1, 2708, 2708), dtype=np.float32))
        >>> output = head(input_data, bias_mat)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 in_drop_ratio=0.0,
                 coef_drop_ratio=0.0,
                 residual=False,
                 coef_activation=nn.LeakyReLU(),
                 activation=nn.ELU()):
        super(AttentionHead, self).__init__()
        self.in_channel = Validator.check_positive_int(in_channel)
        self.out_channel = Validator.check_positive_int(out_channel)
        self.in_drop_ratio = in_drop_ratio
        self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.feature_transform = GNNFeatureTransform(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            has_bias=False)

        self.f_1_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.f_2_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.softmax = nn.Softmax()

        self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
        self.batch_matmul = P.BatchMatMul()
        self.bias_add = P.BiasAdd()
        self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
        self.residual = Validator.check_bool(residual)
        self.residual_transform_flag = False
        if self.residual:
            if in_channel != out_channel:
                self.residual_transform_flag = True
                self.residual_transform = GNNFeatureTransform(
                    in_channels=self.in_channel,
                    out_channels=self.out_channel)
            else:
                self.residual_transform = None
        self.coef_activation = coef_activation
        self.activation = activation

    def construct(self, input_feature, bias_mat):
        input_feature = self.in_drop(input_feature)

        feature = self.feature_transform(input_feature)
        # self attention following the authors: pairwise logits are the sum of
        # per-node scores f_1 and f_2 broadcast over node pairs
        f_1 = self.f_1_transform(feature)
        f_2 = self.f_2_transform(feature)
        logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
        logits = self.coef_activation(logits) + bias_mat
        coefs = self.softmax(logits)

        coefs = self.coef_drop(coefs)
        feature = self.in_drop_2(feature)

        ret = self.batch_matmul(coefs, feature)
        ret = P.Squeeze(0)(ret)
        ret = self.bias_add(ret, self.bias)
        ret = P.ExpandDims()(ret, 0)
        # residual connection
        if self.residual:
            if self.residual_transform_flag:
                res = self.residual_transform(input_feature)
                ret = ret + res
            else:
                ret = ret + input_feature
        # activation
        if self.activation is not None:
            ret = self.activation(ret)
        return ret
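
# Sketch of how a bias_mat is commonly built for the logits above (an assumption
# about the caller, not defined in this file): entries for existing edges are 0
# and entries for non-edges are a large negative number, so softmax(logits +
# bias_mat) assigns near-zero attention to non-neighbors (assumes `np` imported):
#
#   adj = np.random.randint(0, 2, (2708, 2708)).astype(np.float32)   # toy adjacency
#   bias_mat = Tensor((-1e9 * (1.0 - adj))[np.newaxis])              # shape (1, 2708, 2708)
#   head = AttentionHead(1433, 8)
#   out = head(Tensor(np.random.rand(1, 2708, 1433).astype(np.float32)), bias_mat)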


class AttentionAggregator(nn.Cell):
    """
    Attention aggregator for Graph Attention Networks. It applies multiple
    attention heads to the input and can be regarded as one GAT layer.

    Args:
        in_channels (int): Input channel, input feature dim.
        out_channels (int): Output channel, output feature dim of each head.
        num_heads (int): Number of attention heads for this layer, default 1.
        in_drop (float): Input feature dropout ratio, default 0.0.
        coef_drop (float): Coefficient dropout ratio, default 0.0.
        activation (Cell): The output activation function, default nn.ELU().
        residual (bool): Whether to use residual connection, default False.
        output_transform (str['concat', 'sum']): Output transform for the layer,
            default 'concat'.

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape : (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape : (batch_size, num_nodes, num_nodes).

    Examples:
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = AttentionAggregator(1433,
                                      8,
                                      8)
        >>> net(input_data, biases)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_heads=1,
                 in_drop=0.0,
                 coef_drop=0.0,
                 activation=nn.ELU(),
                 residual=False,
                 output_transform='concat'):
        super(AttentionAggregator, self).__init__()
        self.num_heads = num_heads
        self.attns = []
        for _ in range(num_heads):
            self.attns.append(AttentionHead(in_channels,
                                            out_channels,
                                            in_drop_ratio=in_drop,
                                            coef_drop_ratio=coef_drop,
                                            activation=activation,
                                            residual=residual))
        self.attns = nn.layer.CellList(self.attns)
        if output_transform == 'concat':
            self.out_trans = P.Concat(-1)
        elif output_transform == 'sum':
            self.out_trans = P.AddN()
        else:
            raise ValueError("output_transform must be 'concat' or 'sum', but got {}".format(output_transform))

    def construct(self, input_data, bias_mat):
        res = ()
        # Run every attention head and combine the per-head outputs.
        for i in range(self.num_heads):
            res += (self.attns[i](input_data, bias_mat),)
        return self.out_trans(res)
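
# Multi-head shape note (a sketch with the Cora-like sizes used in the docstring
# examples; `np` assumed imported): with num_heads=8, out_channels=8 and
# output_transform='concat', each head returns (1, 2708, 8) and the concatenated
# layer output is (1, 2708, 64).
#
#   layer = AttentionAggregator(1433, 8, num_heads=8)
#   feat = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
#   bias_mat = Tensor(np.zeros((1, 2708, 2708), dtype=np.float32))
#   out = layer(feat, bias_mat)   # shape (1, 2708, 64)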