# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


def _ones_tensor(shape):
    """Return a float32 ``Tensor`` of the given *shape* filled with ones."""
    return Tensor(np.ones(shape), dtype=ms.float32)


def _ones_parameter(shape, name):
    """Return a float32 ``Parameter`` called *name*, of *shape*, filled with ones."""
    return Parameter(_ones_tensor(shape), name=name)


# model_parallel test
def test_six_matmul_save():
    """Compile a six-MatMul network in semi_auto_parallel mode.

    The auto-parallel context is configured with
    ``strategy_ckpt_save_file="./strategy_stage1.ckpt"`` and
    ``group_ckpt_save_file="./group_stage1.ckpt"``. The test only checks that
    graph compilation completes; it makes no assertions on the result.
    """

    class NetWithLoss(nn.Cell):
        """Feed the wrapped network's prediction into ``VirtualLoss``."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6):
            predict = self.network(x1, x6)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Apply ``grad_all`` to the wrapped network for both inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6):
            return grad_all(self.network)(x1, x6)

    class Net(nn.Cell):
        """Chain of six MatMul ops, each sharded with its own explicit strategy."""

        def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.matmul4 = P.MatMul().shard(strategy4)
            self.matmul5 = P.MatMul().shard(strategy5)
            self.matmul6 = P.MatMul().shard(strategy6)
            self.weight1 = _ones_parameter([32, 64], "weight1")
            self.weight2 = _ones_parameter([64, 64], "weight2")
            self.weight3 = _ones_parameter([64, 128], "weight3")
            self.weight4 = _ones_parameter([128, 64], "weight4")
            self.weight5 = _ones_parameter([64, 128], "weight5")
            self.weight6 = _ones_parameter([32, 128], "weight6")

        def construct(self, x1, x6):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul2(out, self.weight2)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            # weight6 is added element-wise rather than multiplied.
            out = out + self.weight6
            out = self.matmul6(out, x6)
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
                              group_ckpt_save_file="./group_stage1.ckpt")
    strategy1 = ((8, 1), (1, 1))
    strategy2 = ((1, 8), (8, 1))
    strategy3 = ((2, 2), (2, 2))
    strategy4 = ((1, 1), (1, 8))
    strategy5 = ((4, 2), (2, 1))
    strategy6 = ((4, 1), (1, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    x1 = _ones_tensor([32, 32])
    x6 = _ones_tensor([128, 32])
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6)


# remove matmul2, add matmul7
def test_six_matmul_load():
    """Compile a modified network with ``strategy_ckpt_load_file`` set.

    Compared with ``test_six_matmul_save`` the network drops matmul2/weight2
    and appends matmul7 (fed by the extra input ``x7``). The context points
    ``strategy_ckpt_load_file`` at "./strategy_stage1.ckpt".

    NOTE(review): presumably that file is the one produced by
    ``test_six_matmul_save``, which would make this test order-dependent;
    confirm against the test runner configuration.
    """

    class NetWithLoss(nn.Cell):
        """Feed the wrapped network's prediction into ``VirtualLoss``."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6, x7):
            predict = self.network(x1, x6, x7)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Apply ``grad_all`` to the wrapped network for all three inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6, x7):
            return grad_all(self.network)(x1, x6, x7)

    class Net(nn.Cell):
        """MatMul chain without matmul2 and with an additional matmul7."""

        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.matmul4 = P.MatMul().shard(strategy4)
            self.matmul5 = P.MatMul().shard(strategy5)
            self.matmul6 = P.MatMul().shard(strategy6)
            self.matmul7 = P.MatMul().shard(strategy7)
            self.weight1 = _ones_parameter([32, 64], "weight1")
            self.weight3 = _ones_parameter([64, 128], "weight3")
            self.weight4 = _ones_parameter([128, 64], "weight4")
            self.weight5 = _ones_parameter([64, 128], "weight5")
            self.weight6 = _ones_parameter([32, 128], "weight6")

        def construct(self, x1, x6, x7):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            # weight6 is added element-wise rather than multiplied.
            out = out + self.weight6
            out = self.matmul6(out, x6)
            out = self.matmul7(out, x7)
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
                              group_ckpt_save_file="./group_stage1.ckpt")
    # Every op is declared with the same strategy here, unlike the save test.
    strategy1 = ((8, 1), (1, 1))
    strategy3 = ((8, 1), (1, 1))
    strategy4 = ((8, 1), (1, 1))
    strategy5 = ((8, 1), (1, 1))
    strategy6 = ((8, 1), (1, 1))
    strategy7 = ((8, 1), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    x1 = _ones_tensor([32, 32])
    x6 = _ones_tensor([128, 32])
    x7 = _ones_tensor([32, 32])
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6, x7)


# model_parallel test
def test_six_matmul_save_auto():
    """Compile the six-MatMul network in auto_parallel mode.

    No operator carries an explicit shard strategy. The context is configured
    with ``strategy_ckpt_save_file="./strategy_stage1_auto.ckpt"``. The test
    only checks that graph compilation completes.
    """

    class NetWithLoss(nn.Cell):
        """Feed the wrapped network's prediction into ``VirtualLoss``."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6):
            predict = self.network(x1, x6)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Apply ``grad_all`` to the wrapped network for both inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6):
            return grad_all(self.network)(x1, x6)

    class Net(nn.Cell):
        """Chain of six MatMul ops with no explicit shard strategies."""

        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.matmul4 = P.MatMul()
            self.matmul5 = P.MatMul()
            self.matmul6 = P.MatMul()
            self.weight1 = _ones_parameter([32, 64], "weight1")
            self.weight2 = _ones_parameter([64, 64], "weight2")
            self.weight3 = _ones_parameter([64, 128], "weight3")
            self.weight4 = _ones_parameter([128, 64], "weight4")
            self.weight5 = _ones_parameter([64, 128], "weight5")
            self.weight6 = _ones_parameter([32, 128], "weight6")

        def construct(self, x1, x6):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul2(out, self.weight2)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            # weight6 is added element-wise rather than multiplied.
            out = out + self.weight6
            out = self.matmul6(out, x6)
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    x1 = _ones_tensor([32, 32])
    x6 = _ones_tensor([128, 32])
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6)


# remove matmul2, add matmul7
def test_six_matmul_load_auto():
    """Compile a modified network in auto_parallel mode with a load file set.

    Compared with ``test_six_matmul_save_auto`` the network drops
    matmul2/weight2 and appends matmul7. matmul1/3/4/5 are given explicit
    strategies, while matmul6 and matmul7 are left unsharded. The context
    points ``strategy_ckpt_load_file`` at "./strategy_stage1_auto.ckpt".

    NOTE(review): presumably that file is the one produced by
    ``test_six_matmul_save_auto``, which would make this test order-dependent;
    confirm against the test runner configuration.
    """

    class NetWithLoss(nn.Cell):
        """Feed the wrapped network's prediction into ``VirtualLoss``."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6, x7):
            predict = self.network(x1, x6, x7)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Apply ``grad_all`` to the wrapped network for all three inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6, x7):
            return grad_all(self.network)(x1, x6, x7)

    class Net(nn.Cell):
        """MatMul chain without matmul2, with matmul7, partially sharded."""

        def __init__(self, strategy1, strategy3, strategy4, strategy5):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.matmul4 = P.MatMul().shard(strategy4)
            self.matmul5 = P.MatMul().shard(strategy5)
            # matmul6 and matmul7 deliberately carry no explicit strategy.
            self.matmul6 = P.MatMul()
            self.matmul7 = P.MatMul()
            self.weight1 = _ones_parameter([32, 64], "weight1")
            self.weight3 = _ones_parameter([64, 128], "weight3")
            self.weight4 = _ones_parameter([128, 64], "weight4")
            self.weight5 = _ones_parameter([64, 128], "weight5")
            self.weight6 = _ones_parameter([32, 128], "weight6")

        def construct(self, x1, x6, x7):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            # weight6 is added element-wise rather than multiplied.
            out = out + self.weight6
            out = self.matmul6(out, x6)
            out = self.matmul7(out, x7)
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1_auto.ckpt")
    strategy1 = ((2, 2), (2, 2))
    strategy3 = ((2, 2), (2, 2))
    strategy4 = ((2, 2), (2, 2))
    strategy5 = ((2, 2), (2, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    x1 = _ones_tensor([32, 32])
    x6 = _ones_tensor([128, 32])
    x7 = _ones_tensor([32, 32])
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6, x7)