# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parallel compile tests for elementwise comparison/min/max ops after MatMul.

Each test builds a two-op network (MatMul followed by a binary elementwise
primitive), applies explicit shard strategies, and checks that the graph
compiles under semi-auto / auto parallel with device_num=8.
"""

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Attach a VirtualLoss to a network so its output can be differentiated."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Return the gradients of `network` with respect to all of its inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net(net, x, y, b):
    """Compile `net` for training with the three given inputs."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def _compile_matmul_elementwise(elementwise_op, strategy1, strategy2,
                                x_shape, y_shape, b_shape,
                                parallel_mode="semi_auto_parallel"):
    """Build and compile a MatMul -> `elementwise_op` network in parallel mode.

    Args:
        elementwise_op: an instantiated binary primitive (e.g. ``P.Equal()``)
            applied to the MatMul output and the third input ``b``.
        strategy1: shard strategy for MatMul, or None to leave it unsharded
            (used by the auto_parallel case).
        strategy2: shard strategy for `elementwise_op`, or None likewise.
        x_shape, y_shape, b_shape: shapes of the float32 all-ones inputs.
        parallel_mode: auto-parallel mode to configure before compiling.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul() if strategy1 is None else P.MatMul().shard(strategy1)
            self.elementwise = elementwise_op if strategy2 is None else elementwise_op.shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.elementwise(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode=parallel_mode)
    net = GradWrap(NetWithLoss(Net()))

    x = Tensor(np.ones(x_shape), dtype=ms.float32)
    y = Tensor(np.ones(y_shape), dtype=ms.float32)
    b = Tensor(np.ones(b_shape), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_equal():
    """MatMul -> Equal with fully specified shard strategies."""
    _compile_matmul_elementwise(P.Equal(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [128, 32], [32, 64], [128, 64])


def test_matmul_not_equal():
    """MatMul -> NotEqual with fully specified shard strategies."""
    _compile_matmul_elementwise(P.NotEqual(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [128, 32], [32, 64], [128, 64])


def test_matmul_approximateEqual():
    """MatMul -> ApproximateEqual(tolerance=0.5) with shard strategies."""
    _compile_matmul_elementwise(P.ApproximateEqual(tolerance=0.5),
                                ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_greater():
    """MatMul -> Greater with fully specified shard strategies."""
    _compile_matmul_elementwise(P.Greater(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_greaterEqual():
    """MatMul -> GreaterEqual with fully specified shard strategies."""
    _compile_matmul_elementwise(P.GreaterEqual(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_less():
    """MatMul -> Less with fully specified shard strategies."""
    _compile_matmul_elementwise(P.Less(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_lessEqual():
    """MatMul -> LessEqual with fully specified shard strategies."""
    _compile_matmul_elementwise(P.LessEqual(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_not_equal_repeated_calculation():
    """MatMul -> NotEqual where strategy2 uses only 4 of the 8 devices."""
    _compile_matmul_elementwise(P.NotEqual(), ((2, 2), (2, 2)), ((4, 1), (4, 1)),
                                [128, 32], [32, 64], [128, 64])


def test_matmul_maximum():
    """MatMul -> Maximum with fully specified shard strategies."""
    _compile_matmul_elementwise(P.Maximum(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_maximum_broadcast():
    """MatMul -> Maximum where b is a 1-D tensor broadcast over the rows."""
    _compile_matmul_elementwise(P.Maximum(), ((2, 2), (2, 2)), ((4, 2), (2,)),
                                [64, 32], [32, 64], [64])


def test_matmul_maximum_broadcast2():
    """MatMul -> Maximum with both operands broadcast along one axis."""
    _compile_matmul_elementwise(P.Maximum(), ((2, 4), (4, 1)), ((4, 1), (1, 2)),
                                [64, 32], [32, 1], [1, 64])


def test_matmul_minimum():
    """MatMul -> Minimum with fully specified shard strategies."""
    _compile_matmul_elementwise(P.Minimum(), ((2, 2), (2, 2)), ((4, 2), (4, 2)),
                                [64, 32], [32, 64], [64, 64])


def test_matmul_minimum_broadcast():
    """MatMul -> Minimum where b is a 1-D tensor broadcast over the rows.

    NOTE(fix): the original used P.Maximum() here — a copy-paste slip from
    test_matmul_maximum_broadcast that left Minimum's broadcast strategy
    untested; now it actually exercises P.Minimum.
    """
    _compile_matmul_elementwise(P.Minimum(), ((2, 2), (2, 2)), ((4, 2), (2,)),
                                [64, 32], [32, 64], [64])


def test_matmul_minimum_broadcast2():
    """MatMul -> Minimum with both operands broadcast along one axis."""
    _compile_matmul_elementwise(P.Minimum(), ((2, 4), (4, 1)), ((4, 1), (1, 2)),
                                [64, 32], [32, 1], [1, 64])


def test_matmul_minimum_auto_parallel():
    """MatMul -> Minimum with no shard strategies under auto_parallel search."""
    _compile_matmul_elementwise(P.Minimum(), None, None,
                                [64, 32], [32, 1], [1, 64],
                                parallel_mode="auto_parallel")