• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``Tensor.view`` inside ``nn.Cell.construct`` in graph mode."""
16import pytest
17
18import mindspore as ms
19import mindspore.nn as nn
20import mindspore.common.initializer as init
21from mindspore import Tensor
22from mindspore import context
23
# Set once at import time so every Net built by the tests below is
# compiled and executed in graph mode.
context.set_context(mode=context.GRAPH_MODE)
25
26
def test_view():
    """Flatten a constant 2x3 Tensor attribute with ``view(-1)``.

    The ``-1`` dimension is inferred from the element count, so the six
    elements should come back as a 1-D tensor.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view(-1)

    net = Net()
    out = net()
    # Check the reshaped result, not merely that the graph ran.
    assert out.shape == (6,)
38
39
def test_view_initializer():
    """Flatten an ``init.initializer``-created [2, 3] float32 value with ``view(-1)``.

    The values come from a 'normal' initializer and are therefore random,
    so only the resulting shape is checked.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = init.initializer('normal', [2, 3], ms.float32)

        def construct(self):
            return self.value.view(-1)

    net = Net()
    out = net()
    # 2 * 3 elements flattened into one dimension.
    assert out.shape == (6,)
51
52
def test_view_1():
    """Reshape a constant 2x3 Tensor attribute to 3x2, passing the shape as a tuple."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view((3, 2))

    net = Net()
    out = net()
    # Check the reshaped result, not merely that the graph ran.
    assert out.shape == (3, 2)
64
65
def test_view_2():
    """Reshape a constant 2x3 Tensor attribute to 3x2, passing the shape as varargs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view(3, 2)

    net = Net()
    out = net()
    # Varargs form must produce the same shape as the tuple form.
    assert out.shape == (3, 2)
77
78
def test_view_parameter():
    """Flatten a 2x3 Tensor passed as a ``construct`` input with ``view(-1)``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view(-1)

    net = Net()
    out = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    # Check the reshaped result, not merely that the graph ran.
    assert out.shape == (6,)
89
90
def test_view_parameter_1():
    """Reshape a 2x3 ``construct`` input to 3x2, passing the shape as a tuple."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view((3, 2))

    net = Net()
    out = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    # Check the reshaped result, not merely that the graph ran.
    assert out.shape == (3, 2)
101
102
def test_view_parameter_2():
    """Reshape a 2x3 ``construct`` input to 3x2, passing the shape as varargs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view(3, 2)

    net = Net()
    out = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    # Varargs form must produce the same shape as the tuple form.
    assert out.shape == (3, 2)
113
114
def test_view_shape_error():
    """Calling ``view()`` with no shape at all must raise ``ValueError``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view()

    expected_msg = "The shape variable should not be empty"
    cell = Net()
    # ``match`` does a regex search on the message; the expected text has no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(ValueError, match=expected_msg):
        cell()
128
129
def test_view_shape_error_1():
    """Passing more than one shape tuple to ``view`` must raise ``ValueError``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view((2, 3), (4, 5))

    expected_msg = "Only one tuple is needed"
    cell = Net()
    # ``match`` does a regex search on the message; the expected text has no
    # regex metacharacters, so this is a plain substring check.
    with pytest.raises(ValueError, match=expected_msg):
        cell()
143