# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compile-only tests for the Tile operator under parallel execution modes."""
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Tile a weight, multiply it with the input, then multiply by a second weight.

    Args:
        weight: Tensor fed to Tile; wrapped in a Parameter named "w1" when
            ``is_parameter`` is true, otherwise used as-is.
        weight2: Tensor wrapped in a Parameter named "w2" for the final Mul.
        strategy1: Shard strategy passed to the first Mul (may be None).
        strategy2: Shard strategy passed to Tile (may be None).
        is_parameter: Whether ``weight`` is registered as a Parameter.
    """

    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.tile = P.Tile().shard(strategy2)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        # The second Mul is created without an explicit shard strategy.
        self.mul2 = P.Mul()
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        """Return ``x * tile(weight, (8, 4, 2)) * weight2``; ``b`` is not used."""
        out = self.tile(self.weight, (8, 4, 2))
        out = self.mul(x, out)
        out = self.mul2(out, self.weight2)
        return out


class Net2(Cell):
    """Multiply the input by a weight and tile the product (Tile is the output op).

    Args:
        weight2: Tensor wrapped in a Parameter named "w2".
        strategy1: Shard strategy passed to Mul (may be None).
        strategy2: Shard strategy passed to Tile (may be None).
    """

    def __init__(self, weight2, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.tile = P.Tile().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        """Return ``tile(x * weight2, (8, 8, 4, 2))``; ``b`` is not used."""
        out = self.mul(x, self.weight2)
        # Four multiples are applied to the Mul result, one more than the
        # rank of the module-level 3-D test tensors.
        out = self.tile(out, (8, 8, 4, 2))
        return out


class Net3(Cell):
    """Tile a weight along the first axis only and multiply it with the input.

    Args:
        weight: Tensor fed to Tile; wrapped in a Parameter named "w1" when
            ``is_parameter`` is true, otherwise used as-is.
        strategy1: Shard strategy passed to Mul (may be None).
        strategy2: Shard strategy passed to Tile (may be None).
        is_parameter: Whether ``weight`` is registered as a Parameter.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.tile = P.Tile().shard(strategy2)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        # NOTE: mul2 is never called in construct(); kept so the cell's
        # attributes stay unchanged.
        self.mul2 = P.Mul()

    def construct(self, x, b):
        """Return ``x * tile(weight, (8, 1, 1))``; ``b`` is not used."""
        out = self.tile(self.weight, (8, 1, 1))
        out = self.mul(x, out)
        return out


# Shared all-ones float32 fixtures. _x and _w3 are not referenced by any test
# in this file; they are kept as-is.
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16, 16, 16]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net, x=_b, b=_b):
    """Wrap ``net`` in a Momentum training step and compile it for auto parallel.

    Args:
        net: The Cell under test.
        x: First input passed to the compiled training cell.
        b: Second input passed to the compiled training cell.

    The auto-parallel context is reset on exit even when compilation raises,
    so a failing test cannot leak its parallel settings into later tests.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, x, b)
    finally:
        context.reset_auto_parallel_context()


def test_tile_parameter():
    """Compile Net with a Parameter Tile input and a (2, 2, 2) Tile strategy."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 2),)
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_tile_parameter_no_full_split():
    """Compile Net with a Parameter Tile input and a (2, 2, 1) Tile strategy.

    The Tile strategy's product (4) is smaller than device_num (8).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 1),)
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_tile_tensor():
    """Compile Net with a plain Tensor Tile input and a (2, 2, 2) Tile strategy."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 2),)
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_tile_tensor_no_full_split():
    """Compile Net with a plain Tensor Tile input and a (2, 2, 1) Tile strategy.

    The Tile strategy's product (4) is smaller than device_num (8).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 1),)
    net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_tile_tensor_no_full_split2():
    """Compile Net3 with (2, 2, 1) strategies on both Mul and Tile.

    Net3 is built with its default ``is_parameter=True`` and fed the
    [128, 16, 16] input ``_x1``.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1), (2, 2, 1))
    strategy2 = ((2, 2, 1),)
    net = Net3(_w1, strategy1, strategy2)
    compile_net(net, _x1, _b)


def test_tile_output():
    """Compile Net2 (Tile as the last op) with a (1, 2, 2, 2) Tile strategy."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2, 2),)
    net = Net2(_w2, strategy1, strategy2)
    compile_net(net)


def test_tile_output_no_full_split():
    """Compile Net2 (Tile as the last op) with a (1, 2, 1, 2) Tile strategy.

    The Tile strategy's product (4) is smaller than device_num (8).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 1, 2),)
    net = Net2(_w2, strategy1, strategy2)
    compile_net(net)


def test_tile_no_strategy():
    """Compile Net2 in semi-auto mode with no shard strategy set on Tile."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = None
    net = Net2(_w2, strategy1, strategy2)
    compile_net(net)


def test_tile_auto_parallel():
    """Compile Net2 in auto_parallel mode with no shard strategies supplied."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2)
    compile_net(net)


def test_tile_auto_parallel_2():
    """Compile Net3 in auto_parallel mode with no shard strategies supplied."""
    # Uses "auto_parallel" to match the test's name and its sibling
    # test_tile_auto_parallel; no strategies are passed to Net3.
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net3(_w1)
    compile_net(net, _x1, _b)