# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Networks and compile helpers for tests that shard the Stack operator."""
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.nn import Dense, Flatten


class Net(nn.Cell):
    """Stack two weights, then multiply the input by the stacked result.

    Args:
        weight1: First stack operand. Wrapped in a ``Parameter`` named "w1"
            when ``is_parameter`` is true, otherwise stored exactly as passed.
        weight2: Second stack operand; always wrapped in a ``Parameter``
            named "w2".
        axis: Axis handed to ``P.Stack``.
        strategy1: Value handed to ``shard`` on the Stack operator.
        strategy2: Value handed to ``shard`` on the Mul operator.
        is_parameter: Whether ``weight1`` is wrapped in a ``Parameter``.
    """

    def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.pack = P.Stack(axis=axis).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        if is_parameter:
            self.weight1 = Parameter(weight1, "w1")
        else:
            self.weight1 = weight1
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x):
        """Return ``x * Stack([weight1, weight2])``."""
        out = self.pack([self.weight1, self.weight2])
        out = self.mul(x, out)
        return out


class Net1(nn.Cell):
    """Multiply the input by ``weight1``, then stack the product with ``weight2``.

    Args:
        weight1: Multiplier; wrapped in a ``Parameter`` named "w1".
        weight2: Second stack operand; wrapped in a ``Parameter`` named "w2".
        axis: Axis handed to ``P.Stack``.
        strategy1: Value handed to ``shard`` on the Stack operator.
        strategy2: Value handed to ``shard`` on the Mul operator.
    """

    def __init__(self, weight1, weight2, axis=0, strategy1=None, strategy2=None):
        super().__init__()
        self.pack = P.Stack(axis=axis).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        self.weight1 = Parameter(weight1, "w1")
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x):
        """Return ``Stack([x * weight1, weight2])``."""
        out = self.mul(x, self.weight1)
        out = self.pack([out, self.weight2])
        return out


class Net2(nn.Cell):
    """Stack three weights, then multiply the input by the stacked result.

    Args:
        weight1: First stack operand. Wrapped in a ``Parameter`` named "w1"
            when ``is_parameter`` is true, otherwise stored exactly as passed.
        weight2: Second stack operand; wrapped in a ``Parameter`` named "w2".
        weight3: Third stack operand; wrapped in a ``Parameter`` named "w3".
        axis: Axis handed to ``P.Stack``.
        strategy1: Value handed to ``shard`` on the Stack operator.
        strategy2: Value handed to ``shard`` on the Mul operator.
        is_parameter: Whether ``weight1`` is wrapped in a ``Parameter``.
    """

    def __init__(self, weight1, weight2, weight3, axis=0, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.pack = P.Stack(axis=axis).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        if is_parameter:
            self.weight1 = Parameter(weight1, "w1")
        else:
            self.weight1 = weight1
        self.weight2 = Parameter(weight2, "w2")
        # "w3" must be built from weight3; it was previously built from
        # weight2, which silently discarded the weight3 argument.
        self.weight3 = Parameter(weight3, "w3")

    def construct(self, x):
        """Return ``x * Stack([weight1, weight2, weight3])``."""
        out = self.pack([self.weight1, self.weight2, self.weight3])
        out = self.mul(x, out)
        return out


class PackConstantNet1(nn.Cell):
    """Stack eight copies of a constant tensor (given as a list), then Mul and Dense.

    Args:
        dense_in_channel: ``in_channels`` of the Dense layer.
        dense_out_channel: ``out_channels`` of the Dense layer.
        axis: Axis handed to ``P.Stack``.
        shape: Shape of the constant tensor; every element is 0.01 (float32).
        strategy: Shard strategy for the Stack operator; ``shard`` is only
            called when this is not ``None``.
    """

    def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
        super().__init__()
        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
        bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
        self.pack_con = Tensor(np.full(shape, 0.01, dtype=np.float32))
        self.flat = Flatten()
        self.dense = Dense(in_channels=dense_in_channel,
                           out_channels=dense_out_channel,
                           weight_init=Tensor(weight_np),
                           bias_init=Tensor(bias_np),
                           has_bias=True)
        self.mul = P.Mul()
        self.pack = P.Stack(axis)
        if strategy is not None:
            self.pack.shard(strategy)

    def construct(self, inputs):
        """Return ``Dense(Flatten(stacked constants) * Flatten(inputs))``."""
        # The stack operands are deliberately passed as a list here.
        x = self.pack([self.pack_con, self.pack_con, self.pack_con, self.pack_con,
                       self.pack_con, self.pack_con, self.pack_con, self.pack_con])
        x1 = self.flat(x)
        x2 = self.flat(inputs)
        x = self.mul(x1, x2)
        x = self.dense(x)
        return x


class PackConstantNet2(nn.Cell):
    """Stack eight copies of a constant tensor (given as a tuple), then Mul and Dense.

    Args:
        dense_in_channel: ``in_channels`` of the Dense layer.
        dense_out_channel: ``out_channels`` of the Dense layer.
        axis: Axis handed to ``P.Stack``.
        shape: Shape of the constant tensor; every element is 0.01 (float32).
        strategy: Shard strategy for the Stack operator; ``shard`` is only
            called when this is not ``None``.
    """

    def __init__(self, dense_in_channel, dense_out_channel, axis=0, shape=None, strategy=None):
        super().__init__()
        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
        bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
        self.pack_con = Tensor(np.full(shape, 0.01, dtype=np.float32))
        self.flat = Flatten()
        self.dense = Dense(in_channels=dense_in_channel,
                           out_channels=dense_out_channel,
                           weight_init=Tensor(weight_np),
                           bias_init=Tensor(bias_np),
                           has_bias=True)
        self.mul = P.Mul()
        self.pack = P.Stack(axis)
        if strategy is not None:
            self.pack.shard(strategy)

    def construct(self, inputs):
        """Return ``Dense(Flatten(stacked constants) * Flatten(inputs))``."""
        # The stack operands are deliberately passed as a tuple here.
        x = self.pack((self.pack_con, self.pack_con, self.pack_con, self.pack_con,
                       self.pack_con, self.pack_con, self.pack_con, self.pack_con))
        x1 = self.flat(x)
        x2 = self.flat(inputs)
        x = self.mul(x1, x2)
        x = self.dense(x)
        return x


_w1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_w2 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_w3 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_x = Tensor(np.ones([2, 48, 64]), dtype=ms.float32)
_x1 = Tensor(np.ones([48, 64]), dtype=ms.float32)
_x2 = Tensor(np.ones([3, 48, 64]), dtype=ms.float32)
_x_c = Tensor(np.ones([8, 8, 8]), dtype=ms.float32)


def _compile_train_net(net, inputs):
    """Wrap ``net`` in a Momentum-driven TrainOneStepCell and compile it on ``inputs``.

    The auto-parallel context is reset even when compilation raises, so a
    failing case does not leave its parallel settings behind.
    """
    context.set_context(mode=context.GRAPH_MODE)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    try:
        _cell_graph_executor.compile(train_net, inputs)
    finally:
        context.reset_auto_parallel_context()


def compile_net(net):
    """Compile ``net`` for training with the [2, 48, 64] input ``_x``."""
    _compile_train_net(net, _x)


def compile_net1(net):
    """Compile ``net`` for training with the [48, 64] input ``_x1``."""
    _compile_train_net(net, _x1)


def compile_net2(net):
    """Compile ``net`` for training with the [3, 48, 64] input ``_x2``."""
    _compile_train_net(net, _x2)


def compile_net_con(net):
    """Compile ``net`` inside a TrainOneStepCell using the [8, 8, 8] input ``_x_c``."""
    # NOTE(review): unlike compile_net/compile_net1/compile_net2, this helper
    # never calls set_train(); kept as-is in case that is intentional.
    context.set_context(mode=context.GRAPH_MODE)
    momentum_opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    wrapped = TrainOneStepCell(net, momentum_opt)
    wrapped.set_auto_parallel()
    _cell_graph_executor.compile(wrapped, _x_c)
    context.reset_auto_parallel_context()


def _use_parallel_mode(mode):
    """Configure the auto-parallel context for 8 devices, rank 0, in ``mode``."""
    context.set_auto_parallel_context(parallel_mode=mode, device_num=8, global_rank=0)


def test_pack_parameter():
    """Compile Net (both stack operands are Parameters) in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((4, 2), (4, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_parameter_no_full_split():
    """Compile Net with a ((2, 2), (2, 2)) Stack strategy in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((2, 2), (2, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_tensor_and_parameter():
    """Compile Net with weight1 left as a plain Tensor (is_parameter=False)."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((4, 2), (4, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy,
                    is_parameter=False))


def test_pack_output():
    """Compile Net1 (Stack is the last operator) on axis 0 in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((4, 2), (4, 2))
    mul_strategy = ((4, 2), (4, 2))
    compile_net1(Net1(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_output_axis1():
    """Compile Net1 with the Stack axis set to 1 in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((4, 2), (4, 2))
    mul_strategy = ((4, 2), (4, 2))
    compile_net1(Net1(_w1, _w2, axis=1, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_output_no_full_split():
    """Compile Net1 with a ((2, 2), (2, 2)) Stack strategy in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = ((2, 2), (2, 2))
    mul_strategy = ((4, 2), (4, 2))
    compile_net1(Net1(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_no_strategy():
    """Compile Net1 on axis 0 with no Stack strategy (None) in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = None
    mul_strategy = ((4, 2), (4, 2))
    compile_net1(Net1(_w1, _w2, axis=0, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_no_strategy_axis1():
    """Compile Net1 on axis 1 with no Stack strategy (None) in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    stack_strategy = None
    mul_strategy = ((4, 2), (4, 2))
    compile_net1(Net1(_w1, _w2, axis=1, strategy1=stack_strategy, strategy2=mul_strategy))


def test_pack_auto_parallel():
    """Compile Net1 on axis 0 in auto_parallel with no explicit strategies."""
    _use_parallel_mode("auto_parallel")
    compile_net1(Net1(_w1, _w2, axis=0))


def test_pack_auto_parallel_axis1():
    """Compile Net1 on axis 1 in auto_parallel with no explicit strategies."""
    _use_parallel_mode("auto_parallel")
    compile_net1(Net1(_w1, _w2, axis=1))


def test_pack_auto_parallel_3_tensor():
    """Compile Net2 (three stacked weights) in auto_parallel."""
    _use_parallel_mode("auto_parallel")
    compile_net2(Net2(_w1, _w2, _w3))


def test_pack_constant1():
    """Compile PackConstantNet1 (list of constants) in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    # One (4, 1) entry per stacked constant; the network stacks eight of them.
    stack_strategy = ((4, 1),) * 8
    compile_net_con(PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0,
                                     shape=(8, 8), strategy=stack_strategy))


def test_pack_constant2():
    """Compile PackConstantNet2 (tuple of constants) in semi_auto_parallel."""
    _use_parallel_mode("semi_auto_parallel")
    # One (4, 1) entry per stacked constant; the network stacks eight of them.
    stack_strategy = ((4, 1),) * 8
    compile_net_con(PackConstantNet2(dense_in_channel=64, dense_out_channel=4, axis=0,
                                     shape=(8, 8), strategy=stack_strategy))


def test_pack_auto_constant():
    """Compile PackConstantNet1 in auto_parallel with an (8, 1)-per-input strategy."""
    _use_parallel_mode("auto_parallel")
    # One (8, 1) entry per stacked constant; the network stacks eight of them.
    stack_strategy = ((8, 1),) * 8
    compile_net_con(PackConstantNet1(dense_in_channel=64, dense_out_channel=4, axis=0,
                                     shape=(8, 8), strategy=stack_strategy))