• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import numpy as np
2
3import mindspore.nn as nn
4from mindspore import context
5from mindspore.common.parameter import Parameter
6from mindspore.common.tensor import Tensor
7import mindspore.ops as ops
8import mindspore
9
# Switch the MindSpore execution context to graph mode at import time.
context.set_context(mode=context.GRAPH_MODE)
11
12
class TestNoReturn(nn.Cell):
    """Cell whose ``construct`` computes a value but has no return statement."""

    def __init__(self):
        super(TestNoReturn, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # Computes x * y + 1; the only return is commented out below, so
        # nothing is handed back to the caller.
        and_v = x * y
        and_v += 1
        # return and_v
22
23
def test_no_return():
    """Run ``TestNoReturn`` on 2x2 ones/zeros tensors and print the result."""
    net = TestNoReturn()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
30
31
class TestSuper(nn.Cell):
    """Cell whose ``construct`` calls ``super`` with three arguments."""

    def __init__(self):
        super().__init__()
        self.m = 1

    def construct(self, x, y):
        # super() is given three arguments and a method ``aa`` that is not
        # defined in this file. NOTE(review): presumably an intentionally
        # invalid call - confirm before changing it.
        super(TestSuper, 2, 3).aa()
        and_v = x * y
        return and_v
41
42
def test_super():
    """Run ``TestSuper`` on 2x2 ones/zeros tensors and print the result."""
    net = TestSuper()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
49
50
class TestCompare(nn.Cell):
    """Cell whose ``construct`` returns a chained comparison."""

    def __init__(self):
        super(TestCompare, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # Chained form ``x > y > 10`` (two comparison operators in one
        # expression) is the construct being exercised here.
        return x > y > 10
58
59
def test_compare():
    """Run ``TestCompare`` on 2x2 ones/zeros tensors and print the result."""
    net = TestCompare()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
66
67
class TestUndefMemberChange(nn.Cell):
    """Cell whose ``construct`` assigns an attribute never set in ``__init__``."""

    def __init__(self):
        super(TestUndefMemberChange, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # ``self.t`` is not created in __init__; it is first written here.
        self.t = x
        return x > y
76
77
def test_undef_member_changer():
    """Run ``TestUndefMemberChange`` on 2x2 ones/zeros tensors and print the result."""
    net = TestUndefMemberChange()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
84
85
class TestMemberChange(nn.Cell):
    """Cell whose ``construct`` reassigns a tensor attribute set in ``__init__``."""

    def __init__(self):
        super(TestMemberChange, self).__init__()
        # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
        # the resulting dtype is the same.
        self.t = Tensor(np.zeros([2, 2], np.float64))

    def construct(self, x, y):
        # Overwrites the attribute created in __init__ with the first input.
        self.t = x
        return x > y
94
95
def test_member_changer():
    """Run ``TestMemberChange`` on 2x2 ones/zeros tensors and print the result."""
    net = TestMemberChange()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
102
103
class TestUnsupportSTMT(nn.Cell):
    """Cell whose ``construct`` contains a ``try``/``finally`` statement."""

    def __init__(self):
        super(TestUnsupportSTMT, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # The try/finally block is the statement under test; ``finally``
        # always rebinds ``val`` to ``x`` before the return.
        try:
            val = x + y
        finally:
            val = x
        return val
115
116
def test_UnsupportSTMT():
    """Run ``TestUnsupportSTMT`` on 2x2 ones/zeros tensors and print the result."""
    net = TestUnsupportSTMT()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
123
124
class TestUnsupportNum(nn.Cell):
    """Cell whose ``construct`` adds a complex literal to its first input."""

    def __init__(self):
        super(TestUnsupportNum, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # ``3.14j`` is a complex-number literal; ``y`` is accepted but unused.
        a = x + 3.14j
        return a
133
134
def test_UnsupportNum():
    """Run ``TestUnsupportNum`` on 2x2 ones/zeros tensors and print the result."""
    net = TestUnsupportNum()
    # np.float64 replaces the np.float alias, which NumPy 1.24 removed;
    # the resulting dtype is the same.
    x = Tensor(np.ones([2, 2], np.float64))
    y = Tensor(np.zeros([2, 2], np.float64))
    ret = net(x, y)
    print(ret)
141
142
class TestAssignAdd(nn.Cell):
    """Cell whose ``construct`` applies ``+=`` to an attribute of its input."""

    def __init__(self):
        super(TestAssignAdd, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # Augmented assignment targets the attribute ``x.id_``; a subscript
        # variant is kept commented out below.
        x.id_ += y
        # x[1] += y
        return x
152
153
def test_AssignAdd():
    """Call ``TestAssignAdd`` with a list and an int and print what comes back."""
    cell = TestAssignAdd()
    print(cell([3, 1], 2))
158
159
class TestParseListComp(nn.Cell):
    """Cell whose ``construct`` uses a list comprehension with two ``for`` clauses."""

    def __init__(self):
        super(TestParseListComp, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # Iterates each ``l`` in ``x``, then each ``m`` in ``l``, adding ``y``.
        ret = [m + y for l in x for m in l]
        return ret
168
169
def test_ParseListComp():
    """Call ``TestParseListComp`` with a nested list and an int and print the result."""
    cell = TestParseListComp()
    nested = [[1, 2], [3, 4]]
    print(cell(nested, 2))
175
176
class TestAssign(nn.Cell):
    """Cell whose ``construct`` assigns to an attribute of its input."""

    def __init__(self):
        super(TestAssign, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # Plain assignment whose target is the attribute ``x.id_``.
        x.id_ = y
        return x
185
186
def test_Assign():
    """Call ``TestAssign`` with a list and an int and print what comes back."""
    cell = TestAssign()
    print(cell([3, 1], 2))
191
192
class TestAssignList(nn.Cell):
    """Cell whose ``construct`` unpacks into a list-form assignment target."""

    def __init__(self):
        super(TestAssignList, self).__init__()
        self.m = 1

    def construct(self, x, y):
        # The assignment target is written as a list pattern ``[m, n]``.
        [m, n] = [x, y]
        return m, n
201
202
def test_AssignList():
    """Call ``TestAssignList`` with a list and an int and print what comes back."""
    cell = TestAssignList()
    print(cell([3, 1], 2))
207
208
class TestParaDef(nn.Cell):
    """Cell whose ``construct`` declares default values for both parameters."""

    def __init__(self):
        super(TestParaDef, self).__init__()
        self.m = 1

    def construct(self, x=1, y=1):
        # Both parameters default to 1; the result is their sum.
        ret = x + y
        return ret
217
218
def test_para_def():
    """Call ``TestParaDef`` with explicit arguments 1 and 2 and print the result."""
    cell = TestParaDef()
    print(cell(1, 2))
223
224
class TestParameterNameNone(nn.Cell):
    """Cell holding a ``Parameter`` created with ``name=None``, used in a matmul."""

    def __init__(self):
        super(TestParameterNameNone, self).__init__()
        self.matmul = ops.MatMul()
        # Named variant kept for reference; the active line passes name=None.
        # self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name=None, requires_grad=True)

    def construct(self, x):
        # Multiplies the (1, 2) weight by the input.
        out = self.matmul(self.weight, x)
        return out
235
236
def test_parameter_name_none():
    """Feed a (2, 1) float32 ones tensor to ``TestParameterNameNone`` and print the output."""
    cell = TestParameterNameNone()
    ones_input = Tensor(np.ones((2, 1)), mindspore.float32)
    result = cell(ones_input)
    print(result)
241
242
class TestBranchReturn(nn.Cell):
    """Cell whose ``construct`` returns from inside an ``if`` and again after it."""

    def __init__(self):
        super(TestBranchReturn, self).__init__()
        self.m = 1

    def construct(self, x):
        # Positive inputs return x + 1 from the branch; everything else
        # falls through to the plain return.
        if x > 0:
            return x + 1

        return x
253
254
def test_branch_return():
    """Call ``TestBranchReturn`` with 1 and print the result."""
    cell = TestBranchReturn()
    result = cell(1)
    print(result)
258
259
class TestSliceNotInt(nn.Cell):
    """Cell whose ``construct`` indexes a string with a slice that has a float stop."""

    def __init__(self):
        super(TestSliceNotInt, self).__init__()
        self.m = 1

    def construct(self, x):
        s = "ABCDEFGHIJKL"
        # The slice stop is the non-integer 4.5; the start comes from the input.
        sl = slice(x, 4.5)
        return s[sl]
269
270
def test_slice_not_int():
    """Call ``TestSliceNotInt`` with 1 and print the result."""
    cell = TestSliceNotInt()
    result = cell(1)
    print(result)
274
275
class TestSliceNotIntDefInInit(nn.Cell):
    """Cell that builds a float-stop slice in ``__init__`` and applies it in ``construct``."""

    def __init__(self):
        super(TestSliceNotIntDefInInit, self).__init__()
        # The slice stop is the non-integer 4.5.
        self.sl = slice(1, 4.5)

    def construct(self, x):
        # ``x`` is accepted but unused; the stored slice indexes the string.
        s = "ABCDEFGHIJKL"
        return s[self.sl]
284
285
def test_slice_not_int_def_in_init():
    """Call ``TestSliceNotIntDefInInit`` with 1 and print the result."""
    cell = TestSliceNotIntDefInInit()
    result = cell(1)
    print(result)
289
290
class MatMulCell(nn.Cell):
    """Pass-through cell: ``construct`` returns its input unchanged.

    Despite the name, no matrix multiplication is performed here.
    """

    def __init__(self):
        super().__init__()
        self.m = 1

    def construct(self, x):
        return x
298
299
class TestCellPipelineStage(nn.Cell):
    """Chain two ``MatMulCell`` blocks, each given ``pipeline_stage = -1``.

    ``strategy1``, ``strategy2`` and ``param`` are accepted but not used.
    """

    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        self.block = nn.CellList()
        # Both sub-cells are built identically: stage assigned, then appended.
        for _ in range(2):
            stage_cell = MatMulCell()
            stage_cell.pipeline_stage = -1
            self.block.append(stage_cell)

    def construct(self, x):
        # NOTE(review): left in its original index-loop form; graph mode
        # parses this source, so restyling it could change what is compiled.
        for i in range(2):
            x = self.block[i](x)
        return x
315
316
def test_cell_pipeline_state():
    """Build ``TestCellPipelineStage`` from two int64 tensors, call it with 1, print the result."""
    first_strategy = Tensor((4, 1), mindspore.int64)
    second_strategy = Tensor((2, 1), mindspore.int64)
    cell = TestCellPipelineStage(first_strategy, second_strategy)
    result = cell(1)
    print(result)
322
323
class TestArgsKwArgs(nn.Cell):
    """Cell whose ``construct`` takes ``*args`` and ``**kwargs`` and sums the positionals."""

    def __init__(self):
        super(TestArgsKwArgs, self).__init__()
        self.m = 1

    def construct(self, *args, **kwargs):
        # Accumulate every positional argument starting from 0.
        x = 0
        for v in args:
            x += v

        # Keyword arguments are accepted but not summed; that loop is disabled.
        # for k, v in kwargs.items():
        #     x += v
        return x
337
338
def test_args_kwargs():
    """Call ``TestArgsKwArgs`` with four positionals and two keywords and print the result."""
    cell = TestArgsKwArgs()
    result = cell(1, 2, 3, 4, k1=5, k2=6)
    print(result)
342
343
class TestArgs(nn.Cell):
    """Cell whose ``construct`` adds every extra positional argument to ``x``."""

    def __init__(self):
        super(TestArgs, self).__init__()
        self.m = 1

    def construct(self, x, *args):
        # Fold the variadic positionals into the first argument.
        for v in args:
            x += v

        return x
354
355
def test_args():
    """Call ``TestArgs`` with 1, 2, 3, 4 and print the result."""
    cell = TestArgs()
    result = cell(1, 2, 3, 4)
    print(result)
359
360
class TestNoDef(nn.Cell):
    """Cell whose ``construct`` reads an attribute that ``__init__`` never sets."""

    def __init__(self):
        super().__init__()
        self.m = 1

    def construct(self, x):
        # ``self.y`` is not assigned anywhere in this class.
        x += self.y
        return x
369
370
def test_no_def():
    """Call ``TestNoDef`` with 1 and print the result."""
    cell = TestNoDef()
    result = cell(1)
    print(result)
374