• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test nn.Dense """
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore.ops import operations as P
23from mindspore.common.api import _cell_graph_executor
24from ..ut_filter import non_graph_engine
25
26
def test_dense_none():
    """nn.Dense must raise TypeError when both initializer arguments are None."""
    in_channels, out_channels = 3, 2
    weight_init = bias_init = None
    with pytest.raises(TypeError):
        nn.Dense(in_channels, out_channels, weight_init, bias_init)
30
31
@non_graph_engine
def test_dense_str_activation():
    """A string activation name ('relu') is resolved to an nn.ReLU cell and the layer runs."""
    layer = nn.Dense(1, 1, activation='relu')
    assert isinstance(layer.activation, nn.ReLU)

    # One sample with one feature: a random integer in [0, 255) cast to float32.
    sample = np.random.randint(0, 255, [1, 1]).astype(np.float32)
    layer(Tensor(sample))
39
40
@non_graph_engine
def test_dense_nn_activation_():
    """An nn.ReLU cell given as the activation is kept as an nn.ReLU and the layer runs."""
    layer = nn.Dense(1, 1, activation=nn.ReLU())
    assert isinstance(layer.activation, nn.ReLU)

    # One sample with one feature: a random integer in [0, 255) cast to float32.
    sample = np.random.randint(0, 255, [1, 1]).astype(np.float32)
    layer(Tensor(sample))
48
49
@non_graph_engine
def test_dense_ops_activation_():
    """A P.ReLU primitive given as the activation is kept as a P.ReLU and the layer runs."""
    layer = nn.Dense(1, 1, activation=P.ReLU())
    assert isinstance(layer.activation, P.ReLU)

    # One sample with one feature: a random integer in [0, 255) cast to float32.
    sample = np.random.randint(0, 255, [1, 1]).astype(np.float32)
    layer(Tensor(sample))
57
58
def test_dense_weight_error():
    """Weight tensors of the wrong rank or shape are rejected with ValueError."""
    # Rank-3 weight; this test expects nn.Dense to reject it.
    rank3_weight = Tensor(np.array([[[0.1], [0.3], [0.6]],
                                    [[0.4], [0.5], [0.2]]]))
    with pytest.raises(ValueError):
        nn.Dense(3, 2, rank3_weight)

    # A 2x3 weight is laid out as (out_channels, in_channels), so it would only
    # fit Dense(3, 2); layers of size (2, 2) and (3, 3) must both reject it.
    weight_2x3 = Tensor(np.array([[0.1, 0.3, 0.6], [0.4, 0.5, 0.2]]))
    for in_ch, out_ch in ((2, 2), (3, 3)):
        with pytest.raises(ValueError):
            nn.Dense(in_ch, out_ch, weight_2x3)
69
70
def test_dense_bias_error():
    """Bias tensors of the wrong rank or length are rejected with ValueError."""
    bad_bias_values = (
        [[0.5, 0.3]],        # 2-D instead of 1-D
        [0.5, 0.3, 0.4],     # length 3, but the layer has out_channels=2
    )
    for values in bad_bias_values:
        bias = Tensor(np.array(values))
        with pytest.raises(ValueError):
            nn.Dense(3, 2, bias_init=bias)
79
80
def test_dense_channels_error():
    """Non-positive channel counts are rejected with ValueError."""
    # (in_channels, out_channels): zero outputs, then negative inputs.
    for in_channels, out_channels in ((3, 0), (-1, 2)):
        with pytest.raises(ValueError):
            nn.Dense(in_channels, out_channels)
87
88
class Net(nn.Cell):
    """Minimal cell wrapping a single nn.Dense layer, used by the compile tests.

    Args:
        input_channels: Number of input features of the Dense layer.
        output_channels: Number of output features of the Dense layer.
        weight: Weight initializer (a name such as 'normal', or a Tensor),
            forwarded positionally to nn.Dense.
        bias: Bias initializer (a name such as 'zeros', or a Tensor),
            forwarded positionally to nn.Dense.
        has_bias: Whether the Dense layer adds a bias term.
        activation: Optional activation forwarded to nn.Dense.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 weight='normal',
                 bias='zeros',
                 has_bias=True,
                 activation=None):
        super().__init__()
        # Positional order matches nn.Dense: channels, weight, bias, has_bias.
        dense_args = (input_channels, output_channels, weight, bias, has_bias)
        self.dense = nn.Dense(*dense_args, activation=activation)

    # Forward pass: apply the wrapped Dense layer to the input.
    def construct(self, input_x):
        output = self.dense(input_x)
        return output
109
110
def test_compile():
    """Compile Net(64, 8) with explicit weight and bias tensors, for inference then training."""
    weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
    bias = Tensor(np.random.randint(0, 255, [8]).astype(np.float32))
    input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))

    # First pass: default (inference) net; second pass: net switched to training.
    for training in (False, True):
        net = Net(64, 8, weight=weight, bias=bias)
        if training:
            net.set_train()
        _cell_graph_executor.compile(net, input_data)
124
125
def test_compile_2():
    """Compile Net(64, 8) without a bias term, for inference then training."""
    weight = Tensor(np.random.randint(0, 255, [8, 64]).astype(np.float32))
    input_data = Tensor(np.random.randint(0, 255, [128, 64]).astype(np.float32))

    # First pass: default (inference) net; second pass: net switched to training.
    for training in (False, True):
        net = Net(64, 8, weight=weight, has_bias=False)
        if training:
            net.set_train()
        _cell_graph_executor.compile(net, input_data)
138
139
def test_compile_3():
    """Compile a default-initialised Net(128, 10) with bias in GRAPH_MODE.

    The net is compiled once as constructed and once after ``set_train()``.
    The global execution mode is switched to GRAPH_MODE for the test and
    restored afterwards, so later tests do not inherit this setting.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.GRAPH_MODE)
    try:
        net = Net(128, 10)
        input_data = Tensor(np.random.randint(0, 255, [128, 128]).astype(np.float32))
        _cell_graph_executor.compile(net, input_data)

        # training
        net_train = Net(128, 10)
        net_train.set_train()
        _cell_graph_executor.compile(net_train, input_data)
    finally:
        # Undo the process-global context change even if compilation fails.
        context.set_context(mode=previous_mode)
153
154
def test_compile_4():
    """Compile a default-initialised Net(128, 10) without bias in GRAPH_MODE.

    The net is compiled once as constructed and once after ``set_train()``.
    The global execution mode is switched to GRAPH_MODE for the test and
    restored afterwards, so later tests do not inherit this setting.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.GRAPH_MODE)
    try:
        net = Net(128, 10, has_bias=False)
        input_data = Tensor(np.random.randint(0, 255, [128, 128]).astype(np.float32))
        _cell_graph_executor.compile(net, input_data)

        # training
        net_train = Net(128, 10, has_bias=False)
        net_train.set_train()
        _cell_graph_executor.compile(net_train, input_data)
    finally:
        # Undo the process-global context change even if compilation fails.
        context.set_context(mode=previous_mode)
168