# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File  : test_create_obj.py
@Author:
@Date  : 2019-06-26
@Desc  : test creating object instances inside a parsed function, e.g. 'construct'.
         Supported classes: nn.Cell, ops.Primitive.
         Supported parameter types: those defined in function 'ValuePtrToPyData'
         (int, float, string, bool, tensor).
"""
import logging
import numpy as np

import mindspore.nn as nn
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.common import Tensor, Parameter
from mindspore.ops import operations as P
from ...ut_filter import non_graph_engine

log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)


class Net(nn.Cell):
    """ Net definition: builds a Softmax Cell inside construct. """

    def __init__(self):
        super(Net, self).__init__()
        # NOTE(review): construct does not use self.softmax; it instantiates a new
        # nn.Softmax instead, which is the behavior this test exercises.
        self.softmax = nn.Softmax(0)
        self.axis = 0

    def construct(self, x):
        # Create a Cell instance inside construct, passing an attribute as its argument.
        x = nn.Softmax(self.axis)(x)
        return x


# Test: create a Cell instance inside construct.
@non_graph_engine
def test_create_cell_object_on_construct():
    """ test_create_cell_object_on_construct """
    log.debug("begin test_create_cell_object_on_construct")
    context.set_context(mode=context.GRAPH_MODE)
    np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(np1)

    net = Net()
    output = net(input_me)
    out_me1 = output.asnumpy()
    print(np1)
    print(out_me1)
    log.debug("finished test_create_cell_object_on_construct")


# Net whose ms_function-decorated construct creates a Primitive instance.
class Net1(nn.Cell):
    """ Net1 definition: builds a P.Add primitive inside construct. """

    def __init__(self):
        super(Net1, self).__init__()
        # NOTE(review): construct does not use self.add; it creates its own P.Add.
        self.add = P.Add()

    @ms_function
    def construct(self, x, y):
        # Create a Primitive instance inside the compiled function.
        add = P.Add()
        result = add(x, y)
        return result


# Test: create a Primitive instance inside an ms_function-decorated construct.
@non_graph_engine
def test_create_primitive_object_on_construct():
    """ test_create_primitive_object_on_construct """
    log.debug("begin test_create_primitive_object_on_construct")
    # Set the mode explicitly, like the other tests, so the result does not
    # depend on which test happened to run first.
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]], np.float32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]], np.float32))

    net = Net1()
    # Call construct directly so the ms_function decorator compiles it.
    net.construct(x, y)
    log.debug("finished test_create_primitive_object_on_construct")


# Cell that takes several constructor arguments (a string and an int); it is
# instantiated inside NetC.construct.
class NetM(nn.Cell):
    """ NetM definition: Softmax over a configurable axis. """

    def __init__(self, name, axis):
        super(NetM, self).__init__()
        # self.relu = nn.ReLU()
        # NOTE(review): `name` is stored but never read in this file.
        self.name = name
        self.axis = axis
        self.softmax = nn.Softmax(self.axis)

    def construct(self, x):
        x = self.softmax(x)
        return x


class NetC(nn.Cell):
    """ NetC definition: builds NetM with literal arguments inside construct. """

    def __init__(self, tensor):
        super(NetC, self).__init__()
        # NOTE(review): self.tensor is not read by construct.
        self.tensor = tensor

    def construct(self, x):
        # Create a Cell inside construct using string and int literals as arguments.
        x = NetM("test", 1)(x)
        return x


# Test: create a Cell instance inside construct using several arguments.
@non_graph_engine
def test_create_cell_object_on_construct_use_many_parameter():
    """ test_create_cell_object_on_construct_use_many_parameter """
    log.debug("begin test_create_cell_object_on_construct_use_many_parameter")
    context.set_context(mode=context.GRAPH_MODE)
    np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me = Tensor(np1)

    net = NetC(input_me)
    output = net(input_me)
    out_me1 = output.asnumpy()
    print(np1)
    print(out_me1)
    log.debug("finished test_create_cell_object_on_construct_use_many_parameter")


class NetD(nn.Cell):
    """ NetD definition: builds a Concat primitive with a keyword argument. """

    def __init__(self):
        super(NetD, self).__init__()

    def construct(self, x, y):
        # Create a Primitive inside construct, configured only via a kwarg.
        concat = P.Concat(axis=1)
        return concat((x, y))


# Test: create a Primitive instance inside construct using kwargs.
@non_graph_engine
def test_create_primitive_object_on_construct_use_kwargs():
    """ test_create_primitive_object_on_construct_use_kwargs """
    log.debug("begin test_create_primitive_object_on_construct_use_kwargs")
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
    y = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
    net = NetD()
    net(x, y)
    log.debug("finished test_create_primitive_object_on_construct_use_kwargs")


class NetE(nn.Cell):
    """ NetE definition: builds a Conv2D primitive with args and kwargs. """

    def __init__(self):
        super(NetE, self).__init__()
        # Weight of shape [16, 16, 3, 3], all ones.
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        out_channel = 16
        kernel_size = 3
        # Mix positional and keyword arguments when creating the Primitive.
        # NOTE(review): the third positional argument (1) is presumably Conv2D's
        # `mode`; confirm against the P.Conv2D signature.
        conv2d = P.Conv2D(out_channel,
                          kernel_size,
                          1,
                          pad_mode='valid',
                          pad=0,
                          stride=1,
                          dilation=1,
                          group=1)
        return conv2d(x, self.w)


# Test: create a Primitive instance inside construct using both args and kwargs.
@non_graph_engine
def test_create_primitive_object_on_construct_use_args_and_kwargs():
    """ test_create_primitive_object_on_construct_use_args_and_kwargs """
    log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs")
    context.set_context(mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    net = NetE()
    net(inputs)
    log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")