• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import numpy as np
2
3import mindspore.nn as nn
4from mindspore import context, Tensor
5from mindspore.ops import operations as P
6from mindspore.ops import composite as C
7
8
9
def setup_module(module):
    """Module-level pytest setup: switch MindSpore to PyNative execution.

    ``module`` is the module object handed to the hook; it is not used here.
    """
    pynative = context.PYNATIVE_MODE
    context.set_context(mode=pynative)
12
13
class Block1(nn.Cell):
    """Cell whose ``construct`` takes one tuple argument ``(x, y)`` and returns ``x * y``."""

    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, tuple_xy):
        # The pair is unpacked inside construct: this is the tuple-parameter
        # path the test exercises.
        lhs, rhs = tuple_xy
        return self.mul(lhs, rhs)
25
class Block2(nn.Cell):
    """Cell returning a tuple nested inside a tuple.

    Output structure: ``(x * y, (x * y + x, x * y + y))``.
    """

    def __init__(self):
        super().__init__()
        self.mul = P.Mul()
        self.add = P.Add()

    def construct(self, x, y):
        prod = self.mul(x, y)
        prod_plus_x = self.add(prod, x)
        prod_plus_y = self.add(prod, y)
        inner = (prod_plus_x, prod_plus_y)
        return prod, inner
39
class Net1(nn.Cell):
    """Packs its two inputs into a tuple and forwards that tuple to ``Block1``."""

    def __init__(self):
        super().__init__()
        self.block = Block1()

    def construct(self, x, y):
        packed = (x, y)
        return self.block(packed)
48
49
class Net2(nn.Cell):
    """Unpacks ``Block2``'s nested-tuple output and sums all three tensors.

    The sum is accumulated left to right: ``(first + second) + third``.
    """

    def __init__(self):
        super().__init__()
        self.add = P.Add()
        self.block = Block2()

    def construct(self, x, y):
        head, (mid, tail) = self.block(x, y)
        acc = self.add(head, mid)
        return self.add(acc, tail)
61
def _assert_grads_equal(grads, expected):
    """Assert that each gradient matches its expectation in count, shape and value.

    Args:
        grads: sequence of numpy arrays (gradients already converted with ``asnumpy()``).
        expected: sequence of numpy arrays holding the expected gradients, in order.

    Raises:
        AssertionError: if the number of gradients, any shape, or any value differs.
    """
    assert len(grads) == len(expected), (
        f"expected {len(expected)} gradients, got {len(grads)}")
    for idx, (grad, want) in enumerate(zip(grads, expected)):
        # Check the shape explicitly: a bare ``np.all(a == b)`` broadcasts, so a
        # scalar or otherwise broadcastable gradient would wrongly pass.
        assert np.shape(grad) == np.shape(want), (
            f"gradient {idx}: shape {np.shape(grad)} != expected {np.shape(want)}")
        np.testing.assert_allclose(grad, want, err_msg=f"gradient {idx}")


def test_net():
    """Check GradOperation(get_all=True) through tuple inputs and nested tuple outputs.

    With x = 2 and y = 3 everywhere:
      * Net1 computes x * y, so d/dx = y and d/dy = x.
      * Net2 computes x*y + (x*y + x) + (x*y + y) = 3*x*y + x + y,
        so d/dx = 3*y + 1 = 10 and d/dy = 3*x + 1 = 7.
    """
    shape = (1, 1, 3, 3)
    x = Tensor(np.full(shape, 2, dtype=np.float32))
    y = Tensor(np.full(shape, 3, dtype=np.float32))
    grad_op = C.GradOperation(get_all=True)

    # Tuple passed as a Cell input (Net1 -> Block1).
    output = grad_op(Net1())(x, y)
    _assert_grads_equal([g.asnumpy() for g in output], [y.asnumpy(), x.asnumpy()])

    # Tuple nested inside a tuple as a Cell output (Block2 -> Net2).
    output = grad_op(Net2())(x, y)
    expect_x = np.full(shape, 10, dtype=np.float32)
    expect_y = np.full(shape, 7, dtype=np.float32)
    _assert_grads_equal([g.asnumpy() for g in output], [expect_x, expect_y])
77