# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests that compile small networks containing Reshape under the
auto_parallel / semi_auto_parallel modes.  Each test only checks that
graph compilation succeeds (via _cell_graph_executor.compile); no numeric
output is asserted."""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient meta-op that returns gradients w.r.t. all inputs.
grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps a single-input network with VirtualLoss so the compiled graph
    ends in a scalar-producing loss node."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Wraps a single-input network so that construct returns the gradients
    of the network w.r.t. its input (forces backward-graph compilation)."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


def test_reshape_matmul():
    """Compile Reshape feeding MatMul under auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            # (64, 28, 1, 1) input is flattened to (64, 28) before the matmul.
            out = self.reshape(x, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_reshape():
    """Compile two chained Reshapes (after a ReLU) under auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu = P.ReLU()

        def construct(self, x):
            x = self.relu(x)
            out = self.reshape(x, (64, 28))
            out = self.reshape(out, (64, 28, 1))
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_reshape_auto_1():
    """Compile ReLU -> Reshape -> MatMul under auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_reshape_auto_2():
    """Compile a net with a Reshape both before and after a MatMul, the
    second Reshape feeding an element-wise add with a parameter."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            # (64, 64) matmul output is reinterpreted as (128, 32) for the add.
            out = self.reshape(out, (128, 32))
            out = out + self.add_weight
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_reshape_auto_3():
    """Compile a Reshape that consumes (rather than produces) the MatMul
    output under auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.matmul(out, self.matmul_weight)
            # (64, 64) -> (8, 8, 8, 8): rank-changing reshape after matmul.
            out = self.reshape(out, (8, 8, 8, 8))
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_reshape_auto_4():
    """Compile a Reshape applied directly to a Parameter (the flat weight is
    reshaped to the matmul's expected 2-D shape) under auto_parallel."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            # Stored flat on purpose; reshaped to (28, 64) inside construct.
            self.matmul_weight = Parameter(Tensor(np.ones([28 * 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            out = self.relu(x)
            out = self.reshape(out, (64, 28))
            w = self.reshape(self.matmul_weight, (28, 64))
            out = self.matmul(out, w)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def test_reshape_auto_5():
    """Compile a wide&deep-style net (mask reshape, ReduceSum, -1 reshapes)
    with two inputs under auto_parallel."""
    class NetWithLoss5(nn.Cell):
        # Two-input variant of the module-level NetWithLoss.
        def __init__(self, network):
            super(NetWithLoss5, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap5(nn.Cell):
        # Two-input variant of the module-level GradWrap.
        def __init__(self, network):
            super(GradWrap5, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024 * 8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            # y gains a trailing axis so it can broadcast as a mask.
            mask = self.reshape(y, (4, 1024 * 8, 1))
            w_id = self.relu(x)
            wx = self.mul(w_id, mask)
            # -1 lets Reshape infer the leading dimension.
            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
            deep_id = x + self.wide_w
            vx = self.mul(deep_id, mask)
            deep_in = self.reshape(vx, (-1, 1024 * 8 * 64))
            out = wide_out + deep_in
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)

    net = GradWrap5(NetWithLoss5(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)

def test_reshape_auto_6():
    """Compile a net where a reshaped Parameter is shared by two consumers
    (a subtraction and a Mul) under auto_parallel."""
    class NetWithLoss6(nn.Cell):
        # Two-input variant of the module-level NetWithLoss.
        def __init__(self, network):
            super(NetWithLoss6, self).__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)

    class GradWrap6(nn.Cell):
        # Two-input variant of the module-level GradWrap.
        def __init__(self, network):
            super(GradWrap6, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            out1 = x + self.wide_w
            # Same parameter used both raw (above) and reshaped (below).
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)

    net = GradWrap6(NetWithLoss6(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)

def test_reshape_auto_7():
    """Compile a reshaped Parameter feeding a Mul with an explicit shard
    strategy under semi_auto_parallel.  NOTE(review): the graph input x is
    unused by Net.construct — the net operates on mul_weight only."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            # Strategy: 3-D lhs sliced (1, 2, 4); 2-D rhs sliced (2, 4).
            self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, x):
            weight = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(weight, self.mul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_depend_reshape():
    """Compile a net where a Depend edge sits between two Reshapes, with
    conflicting shard strategies on Mul and Add; compiled in both
    semi_auto_parallel and auto_parallel modes."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape1 = P.Reshape()
            self.reshape2 = P.Reshape()
            self.relu = P.ReLU()
            self.depend = P.Depend()
            # Mul sliced (2, 4) vs. Add sliced (4, 2): forces redistribution.
            self.mul = P.Mul().shard(((2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
            self.add = P.Add().shard(((4, 2), (4, 2)))

        def construct(self, x, y):
            out1 = self.mul(x, self.mul_weight)
            y = self.relu(y)
            out2 = self.reshape1(y, (96, 32, 4))
            # Depend orders out2 after out1 without a data dependency.
            out3 = self.depend(out2, out1)
            out3 = self.reshape2(out3, (128, 96))
            out = out1 + out3
            return out

    class NetWithLoss1(nn.Cell):
        # Reduces the two-input net's output to a scalar via ReduceMean.
        def __init__(self, network):
            super(NetWithLoss1, self).__init__()
            self.mean = P.ReduceMean(keep_dims=False)
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.mean(predict, ())

    class GradWrap1(nn.Cell):
        # Two-input gradient wrapper (cf. module-level GradWrap).
        def __init__(self, network):
            super(GradWrap1, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
    y = Tensor(np.ones([256, 48]), dtype=ms.float32)
    # First compile with the explicit strategies honored (semi_auto_parallel)...
    net = GradWrap1(NetWithLoss1(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
    # ...then again with strategy search enabled (auto_parallel).
    net_auto = GradWrap1(NetWithLoss1(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net_auto.set_auto_parallel()
    net_auto.set_train()
    _cell_graph_executor.compile(net_auto, x, y)