# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test gnn aggregator."""
import numpy as np
from aggregator import MeanAggregator, AttentionHead, AttentionAggregator

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor

# These tests only compile graphs (no execution), so graph mode is enough.
context.set_context(mode=context.GRAPH_MODE)


# Gradient operator that returns gradients w.r.t. every network input
# (get_all=True) and takes an explicit sensitivity tensor — the initial
# output gradient — as an extra argument (sens_param=True).
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
29
30
class MeanAggregatorGrad(nn.Cell):
    """Backward cell for MeanAggregator.

    Wraps a network and, on construct, returns the gradients of all the
    network's inputs. ``sens`` is the sensitivity (initial output gradient)
    seeding the backward pass.
    """

    def __init__(self, network):
        super(MeanAggregatorGrad, self).__init__()
        self.network = network
        self.grad_op = grad_all_with_sens

    def construct(self, x, sens):
        # Build the gradient function of the wrapped network, then apply it.
        return self.grad_op(self.network)(x, sens)
42
43
def test_MeanAggregator():
    """Compile MeanAggregator forward graph"""
    net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    features = Tensor(np.random.rand(32, 3, 32).astype(np.float32))
    _cell_graph_executor.compile(net, features)
49
50
def test_MeanAggregator_grad():
    """Compile MeanAggregator backward graph"""
    net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    features = Tensor(np.random.rand(32, 3, 32).astype(np.float32))
    # Sensitivity matches the aggregator's (32, 64) output shape.
    sens = Tensor(np.ones((32, 64), np.float32))
    grad_net = MeanAggregatorGrad(net)
    _cell_graph_executor.compile(grad_net, features, sens)
58
59
def test_AttentionHead():
    """Compile AttentionHead forward graph"""
    head = AttentionHead(
        1433, 8, in_drop_ratio=0.6, coef_drop_ratio=0.6, residual=False)
    # Cora-sized input: 2708 nodes with 1433-dim features, plus a bias
    # (adjacency) matrix of shape (1, 2708, 2708).
    features = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    bias_mat = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))
    _cell_graph_executor.compile(head, features, bias_mat)
70
71
def test_AttentionAggregator():
    """Compile AttentionAggregator forward graph"""
    features = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    bias_mat = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))
    net = AttentionAggregator(1433, 8, 8)
    _cell_graph_executor.compile(net, features, bias_mat)
77