# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


# Gradient operation configured with get_all=True; shared by every GradWrap instance.
grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Cell that passes the wrapped network's output through ``VirtualLoss``."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        # Forward through the wrapped network, then apply the loss cell.
        return self.loss(self.network(x, y))


class GradWrap(nn.Cell):
    """Cell that applies ``grad_all`` to the wrapped network."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)


class Net(nn.Cell):
    """Gather ``x`` with a constant all-ones int32 index, then multiply by ``y``.

    Args:
        axis: Axis argument passed to the Gather primitive.
        strategy1: Shard strategy given to the Gather primitive.
        strategy2: Shard strategy given to the Mul primitive.
        shape: Shape of the index tensor; ``[64, 64]`` when ``None``.
        target: Value stored in the Gather primitive's ``primitive_target`` attribute.
    """

    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
        super().__init__()
        index_shape = [64, 64] if shape is None else shape
        gather = P.Gather().shard(strategy1)
        self.gatherv2 = gather.add_prim_attr("primitive_target", target)
        self.mul = P.Mul().shard(strategy2)
        self.index = Tensor(np.ones(index_shape), dtype=ms.int32)
        self.axis = axis

    def construct(self, x, y):
        gathered = self.gatherv2(x, self.index, self.axis)
        return self.mul(gathered, y)


def _compile_gather_net(axis, x_shape, y_shape, strategy1=None, strategy2=None, index_shape=None,
                        parallel_mode="semi_auto_parallel", global_rank=0):
    """Set the 8-device parallel context, build the grad-wrapped ``Net`` and compile it.

    Args:
        axis: Gather axis forwarded to ``Net``.
        x_shape: Shape of the all-ones float32 tensor used as ``x``.
        y_shape: Shape of the all-ones float32 tensor used as ``y``.
        strategy1: Gather shard strategy forwarded to ``Net``.
        strategy2: Mul shard strategy forwarded to ``Net``.
        index_shape: Index shape forwarded to ``Net`` as ``shape``.
        parallel_mode: ``parallel_mode`` for the auto-parallel context.
        global_rank: ``global_rank`` for the auto-parallel context.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=global_rank, parallel_mode=parallel_mode)
    net = GradWrap(NetWithLoss(Net(axis, strategy1, strategy2, shape=index_shape)))
    net.set_auto_parallel()

    x = Tensor(np.ones(x_shape), dtype=ms.float32)
    y = Tensor(np.ones(y_shape), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def test_gatherv2_semi_auto0():
    """Semi-auto compile: axis=0, Gather strategy ((1, 8), (1, 1))."""
    _compile_gather_net(0, [64, 64], [64, 64, 64],
                        strategy1=((1, 8), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto1():
    """Semi-auto compile: axis=0, Gather strategy ((8, 1), (1, 1))."""
    _compile_gather_net(0, [64, 64], [64, 64, 64],
                        strategy1=((8, 1), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto2():
    """Semi-auto compile: axis=0, Gather strategy ((2, 4), (1, 1))."""
    _compile_gather_net(0, [64, 64], [64, 64, 64],
                        strategy1=((2, 4), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto3():
    """Semi-auto compile: axis=1, Gather strategy ((1, 8), (1, 1))."""
    _compile_gather_net(1, [64, 64], [64, 64, 64],
                        strategy1=((1, 8), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto4():
    """Semi-auto compile: axis=1, Gather strategy ((8, 1), (1, 1)), x of shape [64, 32]."""
    _compile_gather_net(1, [64, 32], [64, 64, 64],
                        strategy1=((8, 1), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto5():
    """Semi-auto compile: axis=1, Gather strategy ((2, 4), (1, 1)), x of shape [64, 32]."""
    _compile_gather_net(1, [64, 32], [64, 64, 64],
                        strategy1=((2, 4), (1, 1)), strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto6():
    """Semi-auto compile: axis=0, no Gather strategy given."""
    _compile_gather_net(0, [64, 32], [64, 64, 32], strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto7():
    """Semi-auto compile: axis=1, no Gather strategy given."""
    _compile_gather_net(1, [64, 32], [64, 64, 64], strategy2=((4, 2, 1), (4, 2, 1)))


def test_gatherv2_semi_auto8():
    """Semi-auto compile: axis=0, one-dimensional x with Gather strategy ((8,), (1, 1))."""
    _compile_gather_net(0, [64], [64, 64],
                        strategy1=((8,), (1, 1)), strategy2=((4, 2), (4, 2)))


def test_gatherv2_forward_all_reduce():
    """Semi-auto compile: axis=0, Gather strategy ((8, 1), (1, 1)), index shape [2, 64]."""
    _compile_gather_net(0, [64, 64], [2, 64, 64],
                        strategy1=((8, 1), (1, 1)), strategy2=((2, 4, 1), (2, 4, 1)),
                        index_shape=[2, 64])


def test_gatherv2_shard_batch_and_axis():
    """Semi-auto compile: axis=0, Gather strategy ((4, 1), (2, 1)), index shape [2, 64]."""
    _compile_gather_net(0, [64, 64], [2, 64, 64],
                        strategy1=((4, 1), (2, 1)), strategy2=((2, 4, 1), (2, 4, 1)),
                        index_shape=[2, 64])


def test_gatherv2_split_axis_0_repeat_calc():
    """Semi-auto compile on global_rank=7: axis=0, Gather strategy ((4, 1), (1, 1)), index shape [2, 64]."""
    _compile_gather_net(0, [64, 64], [2, 64, 64],
                        strategy1=((4, 1), (1, 1)), strategy2=((2, 4, 1), (2, 4, 1)),
                        index_shape=[2, 64], global_rank=7)


def test_gatherv2_auto0():
    """Auto-parallel compile: axis=0, no shard strategies given."""
    _compile_gather_net(0, [64, 32], [64, 64, 32], parallel_mode="auto_parallel")


def test_gatherv2_auto1():
    """Auto-parallel compile: axis=1, no shard strategies given."""
    _compile_gather_net(1, [64, 32], [64, 64, 64], parallel_mode="auto_parallel")