# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test gnn aggregator."""
import numpy as np
from aggregator import MeanAggregator, AttentionHead, AttentionAggregator

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor

context.set_context(mode=context.GRAPH_MODE)


# get_all=True: return gradients w.r.t. every network input.
# sens_param=True: the caller passes the output sensitivity (upstream
# gradient) explicitly as the last argument instead of it defaulting to ones.
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)


def _random_tensor(*shape):
    """Return a float32 Tensor of uniform [0, 1) samples with the given shape.

    The values are irrelevant to these tests (they only compile graphs);
    only the shape and dtype matter.
    """
    return Tensor(np.random.rand(*shape).astype(np.float32))


class MeanAggregatorGrad(nn.Cell):
    """Backward of MeanAggregator.

    Wraps ``network`` so that calling this cell with ``(x, sens)`` yields the
    gradients of ``network`` with respect to all of its inputs, using
    ``sens`` as the sensitivity of the network output.
    """

    def __init__(self, network):
        super().__init__()
        self.grad_op = grad_all_with_sens
        self.network = network

    def construct(self, x, sens):
        return self.grad_op(self.network)(x, sens)


def test_MeanAggregator():
    """Compile MeanAggregator forward graph"""
    aggregator = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    input_data = _random_tensor(32, 3, 32)
    _cell_graph_executor.compile(aggregator, input_data)


def test_MeanAggregator_grad():
    """Compile MeanAggregator backward graph"""
    aggregator = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    input_data = _random_tensor(32, 3, 32)
    # Sensitivity for the aggregator output; assumes the output is
    # (32, 64) for these constructor args — NOTE(review): confirm.
    sens = Tensor(np.ones([32, 64]).astype(np.float32))
    grad_op = MeanAggregatorGrad(aggregator)
    _cell_graph_executor.compile(grad_op, input_data, sens)


def test_AttentionHead():
    """Compile AttentionHead forward graph"""
    head = AttentionHead(1433,
                         8,
                         in_drop_ratio=0.6,
                         coef_drop_ratio=0.6,
                         residual=False)
    input_data = _random_tensor(1, 2708, 1433)
    biases = _random_tensor(1, 2708, 2708)
    _cell_graph_executor.compile(head, input_data, biases)


def test_AttentionAggregator():
    """Compile AttentionAggregator forward graph"""
    input_data = _random_tensor(1, 2708, 1433)
    biases = _random_tensor(1, 2708, 2708)
    net = AttentionAggregator(1433, 8, 8)
    _cell_graph_executor.compile(net, input_data, biases)