• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""
Unit tests for ``mindspore.nn.Tril`` (lower-triangular part of a matrix).
"""
18import numpy as np
19
20import mindspore.nn as nn
21from mindspore import Tensor
22from mindspore import context
23
# Run every test in this module in graph mode; this must execute at import
# time, before any Cell below is constructed or called.
context.set_context(mode=context.GRAPH_MODE)
25
26
def test_tril():
    """nn.Tril with k=0 keeps the main diagonal and everything below it."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 0)

    net = Net()
    out = net()
    # Compare element-wise against NumPy's reference implementation; checking
    # only the sum (34) would miss wrong outputs that happen to keep the total.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), 0))
40
41
def test_tril_1():
    """nn.Tril with k=1 also keeps the first diagonal above the main one."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 1)

    net = Net()
    out = net()
    # Compare element-wise against NumPy's reference implementation; checking
    # only the sum (42) would miss wrong outputs that happen to keep the total.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), 1))
55
56
def test_tril_2():
    """nn.Tril with k=-1 drops the main diagonal, keeping only entries below it."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(data)

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, -1)

    net = Net()
    out = net()
    # Compare element-wise against NumPy's reference implementation; checking
    # only the sum (19) would miss wrong outputs that happen to keep the total.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), -1))
70
71
def test_tril_parameter():
    """nn.Tril with k=0 applied to a tensor passed into construct as an input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 0)

    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    net = Net()
    out = net(Tensor(data))
    # Verify the computed values, not merely that the graph compiles and runs.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), 0))
83
84
def test_tril_parameter_1():
    """nn.Tril with k=1 applied to a tensor passed into construct as an input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 1)

    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    net = Net()
    out = net(Tensor(data))
    # Verify the computed values, not merely that the graph compiles and runs.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), 1))
96
97
def test_tril_parameter_2():
    """nn.Tril with k=-1 applied to a tensor passed into construct as an input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, -1)

    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    net = Net()
    out = net(Tensor(data))
    # Verify the computed values, not merely that the graph compiles and runs.
    np.testing.assert_array_equal(out.asnumpy(), np.tril(np.array(data), -1))
109