• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_net_infer """
16import numpy as np
17
18import mindspore.nn as nn
19from mindspore import Tensor, context
20from mindspore.common.parameter import Parameter
21from mindspore.common.initializer import initializer
22import mindspore.ops.operations as op
23
def test_net_infer():
    """ test_net_infer

    Builds a small Conv2d -> ReLU -> Flatten -> Dense cell together with a
    random input tensor and checks only that both can be constructed.

    NOTE(review): construct() is never invoked here, ``self.bn`` is created
    but not used by construct(), and the Dense layer expects 64 input
    features, which the flattened conv output may not provide -- confirm
    before turning this into a real inference test.
    """
    class Net(nn.Cell):
        """ Net definition: conv -> relu -> flatten -> dense. """

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
            self.bn = nn.BatchNorm2d(64)  # NOTE(review): unused by construct()
            self.fc = nn.Dense(64, 10)
            self.relu = nn.ReLU()
            self.flatten = nn.Flatten()

        def construct(self, x):
            return self.fc(self.flatten(self.relu(self.conv(x))))

    # Only construction is exercised; neither object is used afterwards.
    Tensor(np.random.randint(0, 255, [1, 3, 224, 224]))
    Net()
45
46
def test_assign_in_while():
    """ Run an Assign on a Parameter inside a while loop in graph mode.

    The loop counter starts at 1 and is incremented until it reaches 3, so
    the body runs twice; each pass assigns ``z`` into the cell's Parameter,
    and the result of the last Assign is returned.
    """
    context.set_context(device_target="Ascend")
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self, input_shape):
            super().__init__()
            self.assign = op.Assign()
            self.inputdata = Parameter(initializer(1, input_shape), name="global_step")

        def construct(self, x, y, z):
            out = z
            while x < y:
                # The Parameter is read through a local alias inside the loop body.
                inputdata = self.inputdata
                x = x + 1
                out = self.assign(inputdata, z)
            return out

    shape = (1024, 512)
    start = Tensor(np.array(1, dtype=np.int32))
    stop = Tensor(np.array(3, dtype=np.int32))
    value = Tensor(np.random.randn(*shape).astype(np.float32))
    network = Net(shape)
    network(start, stop, value)
70
71
def test_dup_context():
    """ Different uses of func_with_fv in net1 and net2 should produce two
    different FuncGraphAbstractClosure and Evaluator instances.

    Both helpers route ``func_with_fv`` (which captures the construct
    argument ``x``) through ``identity`` and then add different constants.
    """
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x):
            def identity(f):
                return f

            def func_with_fv():
                # Closes over the construct argument ``x`` (a free variable).
                return x

            def net1():
                local_func = identity(func_with_fv)
                return local_func() + 20.0

            def net2():
                local_func = identity(func_with_fv)
                return local_func() + 15.0

            return net1() + net2()

    net = Net()
    net(Tensor(np.array(5.0, dtype=np.float32)))
102
103
def test_maybe_poly_func():
    """ Calling func_with_fv through identity with differently shaped
    arguments may produce a poly node.

    ``func_with_fv`` captures the construct argument ``x`` and is applied
    once to ``y`` (a 1-D tensor) and once to ``z`` (a 2-D tensor).
    """
    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()

        def construct(self, x, y, z):
            def identity(f, inp):
                # Despite the name, this applies ``f`` to ``inp``.
                return f(inp)

            def func_with_fv(yy):
                return (x, yy)

            def make_call():
                return (identity(func_with_fv, y), identity(func_with_fv, z))

            return make_call()

    x_input = Tensor(np.array(1, dtype=np.int32))
    y_input = Tensor(np.array([1, 2], dtype=np.int32))
    z_input = Tensor(np.array([[2, 2], [3, 3]], dtype=np.int32))
    Net()(x_input, y_input, z_input)
129