• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test partial"""
16from functools import partial
17
18import numpy as np
19import pytest
20
21from mindspore import nn, Tensor, context
22
# Select MindSpore GRAPH_MODE at import time, so every Net in this module
# (and the functools.partial calls inside each construct) runs in graph mode.
context.set_context(mode=context.GRAPH_MODE)
24
def test_partial_pos_arg():
    """Bind the leading positional argument with ``partial`` inside ``construct``.

    ``show`` returns its three arguments unchanged, so the network must return
    ``(x, y, z)``. The inputs have distinct shapes ((3,), (3, 4), (3, 4, 5)),
    so a partial that filled the wrong parameter would fail the value check.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, x)
            ret = f(y, z)
            return ret

    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
    net = Net()
    out = net(x, y, z)
    # Check the returned values, not merely that the graph compiles and runs.
    assert len(out) == 3
    for got, expected in zip(out, (x, y, z)):
        assert np.array_equal(got.asnumpy(), expected.asnumpy())
43
def test_partial_key_ward_arg():
    """Bind ``x`` by keyword with ``partial`` and supply ``y``/``z`` by keyword.

    ``show`` echoes its arguments, so the network must return ``(x, y, z)``.
    The inputs have distinct shapes, so a keyword routed to the wrong
    parameter would fail the value check.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, x=x)
            ret = f(y=y, z=z)
            return ret

    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
    net = Net()
    out = net(x, y, z)
    # Check the returned values, not merely that the graph compiles and runs.
    assert len(out) == 3
    for got, expected in zip(out, (x, y, z)):
        assert np.array_equal(got.asnumpy(), expected.asnumpy())
62
def test_partial_key_ward_arg_update():
    """A keyword supplied at call time must override one bound by ``partial``.

    ``y`` is first bound to ``x`` (shape (3,)) and then overridden with ``y``
    (shape (3, 4)). Because the two tensors differ, the output check below
    fails if the override were ignored. This mirrors the constant variant,
    which binds ``y=2`` and overrides it with ``y=3``.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            # Bind a deliberately different tensor for y so the update is observable.
            f = partial(self.show, x=x, y=x)
            ret = f(y=y, z=z)
            return ret

    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
    net = Net()
    out = net(x, y, z)
    assert len(out) == 3
    for got, expected in zip(out, (x, y, z)):
        assert np.array_equal(got.asnumpy(), expected.asnumpy())
81
82
def test_partial_key_ward_arg_and_pos_arg():
    """Mix a keyword bound by ``partial`` with positional and keyword call args.

    ``y`` is bound by keyword. At call time the literal ``2`` fills ``x``
    positionally and ``z`` is passed by keyword. The ``x`` input of
    ``construct`` is not used.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, y=y)
            ret = f(2, z=z)
            return ret

    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
    net = Net()
    out = net(x, y, z)
    assert len(out) == 3
    # Slot 0 holds the literal 2, so only the tensor slots are compared here.
    assert np.array_equal(out[1].asnumpy(), y.asnumpy())
    assert np.array_equal(out[2].asnumpy(), z.asnumpy())
101
102
def test_partial_pos_arg_const():
    """A constant bound positionally by partial() fills the first parameter."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, 1)
            return bound(2, 3)

    expected = (1, 2, 3)
    assert Net()() == expected
118
def test_partial_key_ward_arg_const():
    """Constants bound and passed purely by keyword land in the named parameters."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, x=1)
            return bound(y=2, z=3)

    expected = (1, 2, 3)
    assert Net()() == expected
134
def test_partial_key_ward_arg_update_const():
    """A call-time keyword replaces the value partial() bound for the same name."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, x=1, y=2)
            return bound(y=3, z=4)

    # y=3 from the call supersedes the y=2 stored in the partial.
    expected = (1, 3, 4)
    assert Net()() == expected
150
151
def test_partial_key_ward_arg_and_pos_arg_const():
    """A keyword-bound y combines with a positional x and a keyword z at call time."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, y=2)
            return bound(1, z=3)

    expected = (1, 2, 3)
    assert Net()() == expected
167
168
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
    """Supplying x positionally after partial() bound x by keyword raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, x=1)
            return bound(1, 2, 3)

    net = Net()
    with pytest.raises(TypeError, match="Multiply values for specific argument: x"):
        net()
186
187
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
    """A second positional arg colliding with a keyword-bound y raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, y=2)
            return bound(1, 2, z=3)

    net = Net()
    with pytest.raises(TypeError, match="Multiply values for specific argument: y"):
        net()
205
206
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
    """Three positional args while z is already keyword-bound raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            bound = partial(self.show, z=1)
            return bound(1, 2, 3)

    net = Net()
    with pytest.raises(TypeError, match="Multiply values for specific argument: z"):
        net()
224