• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File  : test_create_obj.py
@Author:
@Date  : 2019-06-26
@Desc  : Test creating object instances inside a parsed function, e.g. 'construct'.
         Supported classes: nn.Cell, ops.Primitive
         Supported argument types: those handled by the function 'ValuePtrToPyData'
                                   (int, float, string, bool, tensor)
"""
24import logging
25import numpy as np
26
27import mindspore.nn as nn
28from mindspore import context
29from mindspore.common.api import ms_function
30from mindspore.common import Tensor, Parameter
31from mindspore.ops import operations as P
32from ...ut_filter import non_graph_engine
33
# Module-wide logger used for the begin/finished trace messages below.
# Only ERROR and above are emitted, so the debug traces stay quiet by default.
log = logging.getLogger("test")
log.setLevel(logging.ERROR)
36
37
class Net(nn.Cell):
    """Cell whose ``construct`` builds a new ``nn.Softmax`` instead of using ``self.softmax``.

    ``self.softmax`` is created in ``__init__`` but ``construct`` never reads it.
    That appears intentional: the test exercises instantiating a Cell inside
    ``construct``.
    """

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(0)
        self.axis = 0

    def construct(self, x):
        # The Softmax cell is instantiated here, with the axis read from self.axis.
        return nn.Softmax(self.axis)(x)
49
50
# Test: create a Cell instance inside construct
@non_graph_engine
def test_create_cell_object_on_construct():
    """Run ``Net``, whose construct instantiates ``nn.Softmax(0)``, in graph mode.

    Besides checking that compilation and execution succeed, verify that the
    output keeps the input shape and that each softmax slice along axis 0
    sums to one.
    """
    log.debug("begin test_create_cell_object_on_construct")
    context.set_context(mode=context.GRAPH_MODE)
    np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(np1)

    net = Net()
    output = net(input_me)
    out_me1 = output.asnumpy()
    log.debug("input: %s", np1)
    log.debug("output: %s", out_me1)
    # Previously the test only checked that nothing raised; pin the result too.
    assert out_me1.shape == np1.shape
    assert np.allclose(out_me1.sum(axis=0), 1.0, atol=1e-5)
    log.debug("finished test_create_cell_object_on_construct")
66
67
# Test: create a Primitive instance inside an ms_function-compiled construct
class Net1(nn.Cell):
    """Cell whose ``@ms_function`` construct creates its own ``P.Add`` primitive.

    ``self.add`` is built in ``__init__`` but ``construct`` does not use it.
    Instead it instantiates a fresh ``P.Add`` inside the compiled function.
    """

    def __init__(self):
        super().__init__()
        self.add = P.Add()

    @ms_function
    def construct(self, x, y):
        add_op = P.Add()
        return add_op(x, y)
81
82
@non_graph_engine
def test_create_primitive_object_on_construct():
    """Call ``Net1.construct`` (which creates ``P.Add`` internally) and check the sum.

    Graph mode is set explicitly, like the sibling tests, so this test does
    not depend on the order in which tests run.
    """
    log.debug("begin test_create_primitive_object_on_construct")
    context.set_context(mode=context.GRAPH_MODE)
    x_np = np.array([[1, 2, 3], [1, 2, 3]], np.float32)
    y_np = np.array([[2, 3, 4], [1, 1, 2]], np.float32)
    x = Tensor(x_np)
    y = Tensor(y_np)

    net = Net1()
    result = net.construct(x, y)
    # Previously the result was discarded; check it is the element-wise sum.
    assert np.allclose(result.asnumpy(), x_np + y_np)
    log.debug("finished test_create_primitive_object_on_construct")
93
94
# Test: create a Cell instance inside construct, passing several constructor arguments
class NetM(nn.Cell):
    """Cell that stores its constructor arguments and applies softmax along ``axis``.

    Args:
        name: stored on ``self.name``; ``construct`` does not read it.
        axis: axis given to the ``nn.Softmax`` cell created in ``__init__``.
    """

    def __init__(self, name, axis):
        super().__init__()
        self.name = name
        self.axis = axis
        self.softmax = nn.Softmax(self.axis)

    def construct(self, x):
        # Apply the softmax cell that was built in __init__.
        return self.softmax(x)
109
110
class NetC(nn.Cell):
    """Cell whose ``construct`` instantiates ``NetM("test", 1)`` and applies it.

    Args:
        tensor: stored on ``self.tensor``; ``construct`` does not read it.
    """

    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def construct(self, x):
        # NetM is created inside construct with a string and an int argument.
        return NetM("test", 1)(x)
121
122
# Test: create a Cell with several constructor arguments inside construct
@non_graph_engine
def test_create_cell_object_on_construct_use_many_parameter():
    """Run ``NetC``, whose construct instantiates ``NetM("test", 1)``, in graph mode.

    Verify that the output keeps the input shape and that the softmax along
    axis 1 (the axis NetC passes to NetM) sums to one.
    """
    log.debug("begin test_create_cell_object_on_construct_use_many_parameter")
    context.set_context(mode=context.GRAPH_MODE)
    np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(np1)

    net = NetC(input_me)
    output = net(input_me)
    out_me1 = output.asnumpy()
    log.debug("input: %s", np1)
    log.debug("output: %s", out_me1)
    # Previously the test only checked that nothing raised; pin the result too.
    assert out_me1.shape == np1.shape
    assert np.allclose(out_me1.sum(axis=1), 1.0, atol=1e-5)
    log.debug("finished test_create_cell_object_on_construct_use_many_parameter")
138
139
class NetD(nn.Cell):
    """Cell whose ``construct`` creates a ``P.Concat`` primitive using a keyword argument."""

    def __init__(self):
        super().__init__()

    def construct(self, x, y):
        # Concat is instantiated here with its axis passed by keyword.
        concat_op = P.Concat(axis=1)
        return concat_op((x, y))
149
150
# Test: create a Primitive inside construct using a keyword argument
@non_graph_engine
def test_create_primitive_object_on_construct_use_kwargs():
    """Run ``NetD``, whose construct builds ``P.Concat(axis=1)``, in graph mode.

    The result must equal NumPy concatenation along axis 1. The inputs are
    2x2, so concatenating along axis 1 (2x4) is distinguishable from axis 0 (4x2).
    """
    log.debug("begin test_create_primitive_object_on_construct_use_kwargs")
    context.set_context(mode=context.GRAPH_MODE)
    x_np = np.array([[0, 1], [2, 1]]).astype(np.float32)
    y_np = np.array([[0, 1], [2, 1]]).astype(np.float32)
    x = Tensor(x_np)
    y = Tensor(y_np)
    net = NetD()
    output = net(x, y)
    # Previously the result was discarded; check the concatenation itself.
    expected = np.concatenate((x_np, y_np), axis=1)
    assert np.array_equal(output.asnumpy(), expected)
    log.debug("finished test_create_primitive_object_on_construct_use_kwargs")
162
163
class NetE(nn.Cell):
    """Cell whose ``construct`` builds ``P.Conv2D`` from mixed positional and keyword args.

    The convolution weight is an all-ones float32 Parameter ``w`` of shape
    [16, 16, 3, 3].
    """

    def __init__(self):
        super().__init__()
        weight = Tensor(np.ones([16, 16, 3, 3]).astype(np.float32))
        self.w = Parameter(weight, name='w')

    def construct(self, x):
        channels_out = 16
        kernel = 3
        # Three positional arguments (out_channel, kernel_size, then 1 -- presumably
        # Conv2D's `mode`; confirm against the installed MindSpore), the rest by keyword.
        conv2d = P.Conv2D(channels_out, kernel, 1,
                          pad_mode='valid', pad=0, stride=1, dilation=1, group=1)
        return conv2d(x, self.w)
183
184
# Test: create a Primitive inside construct using both positional and keyword arguments
@non_graph_engine
def test_create_primitive_object_on_construct_use_args_and_kwargs():
    """Run ``NetE``, whose construct builds a ``P.Conv2D``, in graph mode.

    The input is [1, 16, 16, 16] and the kernel is 3x3 with 'valid' padding
    and stride 1, so the output must be [1, 16, 14, 14].
    """
    log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs")
    context.set_context(mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    net = NetE()
    output = net(inputs)
    # Previously the result was discarded; pin the valid-padding output shape.
    assert output.asnumpy().shape == (1, 16, 14, 14)
    log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")
195