# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Attaches a virtual loss so the compiled graph has a scalar output."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Wraps a network so that construct returns gradients w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


def compile_net(net, x, y, b, phase):
    """Compile the network under auto-parallel so shard strategies are generated."""
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b, phase=phase)


def test_auto_parallel_arithmetic():
    """No broadcasting: both FloorDiv inputs have shape [64, 128]."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 128]), dtype=ms.float32)
    b = Tensor(np.ones([64, 128]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')
    strategies = _cell_graph_executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('FloorDiv-op', k) is not None:
            assert v == [[2, 4], [2, 4]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[2, 1], [1, 4]]


def test_auto_parallel_arithmetic_broadcast_both():
    """Both FloorDiv inputs are broadcast: [64, 1] against [1, 64]."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 1]), dtype=ms.float32)
    b = Tensor(np.ones([1, 64]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')
    strategies = _cell_graph_executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('FloorDiv-op', k) is not None:
            assert v == [[8, 1], [1, 1]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[8, 1], [1, 1]]


def test_auto_parallel_arithmetic_broadcast_right():
    """Only the right FloorDiv input is broadcast: [64, 32] against [32]."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 32]), dtype=ms.float32)
    b = Tensor(np.ones([32]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')
    strategies = _cell_graph_executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('FloorDiv-op', k) is not None:
            assert v == [[4, 2], [2]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[4, 1], [1, 2]]


def test_auto_parallel_arithmetic_broadcast_left():
    """Only the left FloorDiv input is broadcast: [64, 32] against [128, 64, 32]."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.floordiv(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 32]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
    compile_net(net, x, y, b, phase='train')
    strategies = _cell_graph_executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('FloorDiv-op', k) is not None:
            assert v == [[4, 2], [1, 4, 2]]
        elif re.search('MatMul-op', k) is not None:
            assert v == [[4, 1], [1, 2]]
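

# ---------------------------------------------------------------------------
# Illustration only, not part of the original tests: each strategy asserted
# above is a list with one sub-list per operator input, and each sub-list
# gives the shard count per input dimension (e.g. [[2, 1], [1, 4]] shards
# MatMul's first input 2-way on dim 0 and its second input 4-way on dim 1,
# covering 2 * 4 = 8 devices). The hypothetical checker below makes one basic
# invariant of these strategies explicit: the shard counts of any single
# input multiply to a divisor of device_num.
# ---------------------------------------------------------------------------
def _assert_strategy_fits_devices(strategy, device_num=8):
    """Hypothetical helper: every input's total shard count divides device_num."""
    for input_strategy in strategy:
        shards = 1
        for dim_shards in input_strategy:
            shards *= dim_shards
        assert device_num % shards == 0

# Example usage: _assert_strategy_fits_devices([[2, 1], [1, 4]])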