# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import numpy as np
from numpy.random import normal
import pytest

import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.composite import core
from mindspore.common.api import ms_function

from mindspore import Tensor
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer

context.set_context(mode=context.GRAPH_MODE)


class FakeOp(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """"""

    def infer_shape(self, x, y):
        self.second_shape = y
        self.add_prim_attr("second_shape", y)
        return x

    def infer_dtype(self, x, y):
        return x


# test the normal case that should generate independent primitives because of the
# different attributes generated after inference
def test_conv2d_same_primitive():
    class Conv2DSameNet(nn.Cell):
        def __init__(self):
            super(Conv2DSameNet, self).__init__()
            self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            r1 = self.conv1(x)
            r2 = self.conv2(y)
            return (r1, r2)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = Conv2DSameNet()
    net(t1, t2)


# test a free-variable function list passed as a parameter
def test_remove_and_fv_2():
    @core(loop_can_uroll=True)
    def inner_loop(x, input_data, fv_func_list):
        ret = ()
        for fv_fn in fv_func_list:
            ele = fv_fn(input_data)
            ret += (ele,)
        return ret

    @ms_function
    def out_loop(input1, input_data0, input_data1):
        ret = ()

        def fv_func1(y):
            return input1 * y

        def fv_func2(y):
            return input1 - y

        fv_func_list = [fv_func1, fv_func2]
        ele0 = inner_loop(input1, input_data0, fv_func_list)
        ele1 = inner_loop(input1, input_data1, fv_func_list)
        ret = (ele0, ele1)
        return ret

    input_data0 = Tensor(normal(0, 0.1, (3, 3)))
    input_data1 = Tensor(normal(0, 0.1, (3, 1)))
    input1 = Tensor(normal(0, 0.1, (3, 3)))
    out_loop(input1, input_data0, input_data1)


# test a cell as a high-order argument
# A graph with free variables used as an argument is not supported yet
# because of the limits of the inference specialization system
def test_conv2d_op_with_argi_1():
    class Conv2dNet(nn.Cell):
        def __init__(self):
            super(Conv2dNet, self).__init__()

        def construct(self, op, x):
            return op(x)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)

        def construct(self, x, y):
            conv_op = self.conv2
            a = self.opnet(conv_op, x)
            b = self.opnet(conv_op, y)
            return (a, b)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = OpsNet(Conv2dNet())
    net(t1, t2)


# test a fake op wrapped in a cell and passed as a high-order argument
def test_conv2d_op_with_arg():
    class FakeOpNet(nn.Cell):
        def __init__(self):
            super(FakeOpNet, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class OpNet(nn.Cell):
        def __init__(self):
            super(OpNet, self).__init__()

        def construct(self, op, x, y):
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            a = self.opnet(op, x, y)
            b = self.opnet(op, y, x)
            return (a, b)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = OpsNet(OpNet())
    net(t1, t2)


# same as above, but one call passes the same tensor for both inputs
def test_conv2d_op_with_arg_same_input():
    class FakeOpNet(nn.Cell):
        def __init__(self):
            super(FakeOpNet, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    class OpNet(nn.Cell):
        def __init__(self):
            super(OpNet, self).__init__()

        def construct(self, op, x, y):
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FakeOpNet()

        def construct(self, x, y):
            op = self.op
            a = self.opnet(op, x, x)
            b = self.opnet(op, y, x)
            return (a, b)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = OpsNet(OpNet())
    net(t1, t2)


# test an op wrapped with partial
def test_op_as_partial():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = partial_op(y)
            b = partial_op(z)
            return a, b

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = OpAsPartial()
    net(t1, t2, t3)


# test an op wrapped with partial, called through an outer cell
def test_op_as_partial_inside():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = partial_op(y)
            b = partial_op(z)
            return a, b

    class OuterNet(nn.Cell):
        def __init__(self):
            super(OuterNet, self).__init__()
            self.net = OpAsPartial()

        def construct(self, x, y, z):
            a, b = self.net(x, y, z)
            return a, b

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = OuterNet()
    net(t1, t2, t3)


# test an op wrapped with partial, case 2: two independent partials
def test_op_as_partial_independent():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op1 = F.partial(self.op, x)
            a = partial_op1(y)
            partial_op2 = F.partial(self.op, x)
            b = partial_op2(z)
            return a, b

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = OpAsPartial()
    net(t1, t2, t3)
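

# The following test is a minimal sketch added for illustration and is not part of the
# original suite. Assuming F.partial has the partial-application semantics exercised by
# the tests above, binding the first argument up front should build the same call as
# passing both arguments directly, i.e. F.partial(op, x)(y) corresponds to op(x, y).
def test_op_as_partial_vs_direct_call():
    class PartialVsDirect(nn.Cell):
        def __init__(self):
            super(PartialVsDirect, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            # direct two-argument call
            direct = self.op(x, y)
            # the same call expressed through a partial application
            partial_op = F.partial(self.op, x)
            bound = partial_op(y)
            return direct, bound

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = PartialVsDirect()
    net(t1, t2)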


# test nested partial application
def test_nest_partial():
    class NestPartial(nn.Cell):
        def __init__(self):
            super(NestPartial, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op1 = F.partial(self.op)
            partial_op2 = F.partial(partial_op1, x)
            a = partial_op2(y)
            partial_op3 = F.partial(self.op)
            partial_op4 = F.partial(partial_op3, x)
            b = partial_op4(z)
            return a, b

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = NestPartial()
    net(t1, t2, t3)


# test a high-order argument: the op and the op's arguments passed as network arguments
def test_op_with_arg_as_input():
    class WithOpArgNet(nn.Cell):
        def __init__(self):
            super(WithOpArgNet, self).__init__()

        def construct(self, op, x, y):
            return op(x, y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            op = self.op
            a = self.opnet(op, x, z)
            b = self.opnet(op, x, y)
            return (a, b)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = OpsNet(WithOpArgNet())
    net(t1, t2, t3)


# A partial application used as an argument is not supported yet
# because of the limits of the inference specialization system
@pytest.mark.skip("poly in infer")
def test_partial_as_arg():
    class PartialArgNet(nn.Cell):
        def __init__(self):
            super(PartialArgNet, self).__init__()

        def construct(self, partial_op, y):
            return partial_op(y)

    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.partial_net = net
            self.op = FakeOp()

        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = self.partial_net(partial_op, z)
            b = self.partial_net(partial_op, y)
            return (a, b)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
    net = OpsNet(PartialArgNet())
    net(t1, t2, t3)
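

# The following test is a minimal sketch added for illustration and is not part of the
# original suite. It exercises the FakeOp helper defined at the top of this file in the
# simplest possible way: a direct two-argument call inside a cell, relying only on
# FakeOp.infer_shape / infer_dtype (which return the first input's shape and dtype and
# record the second input's shape as the "second_shape" attribute).
def test_fake_op_direct_call():
    class DirectFakeOpNet(nn.Cell):
        def __init__(self):
            super(DirectFakeOpNet, self).__init__()
            self.op = FakeOp()

        def construct(self, x, y):
            return self.op(x, y)

    t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
    t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
    net = DirectFakeOpNet()
    net(t1, t2)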