# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    """Wraps a network with a virtual loss so the whole graph can be compiled."""
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    """Wraps a network so compiling it also builds the backward graph."""
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)

def test_reshape_unexpand():
    """Reshape a [96, 128] weight to (1, 128, 96) before a Mul sharded as ((1, 8), (1, 1, 8))."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
            self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")

        def construct(self, x):
            weight = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(x, weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_1():
    """Reshape the weight to 3D and multiply it with its own 2D value, strategy ((1, 1, 8), (1, 8))."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().shard(((1, 1, 8), (1, 8)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, data):
            x = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(x, self.mul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_2():
    """Same pattern as test_reshape_unexpand_1 with a mixed strategy ((1, 4, 2), (4, 2))."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, data):
            x = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(x, self.mul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_3():
    """Reshape (4, 3) -> (3, 4) between ReLUs sharded as ((4, 1),) and ((1, 4),)."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu1 = P.ReLU().shard(((4, 1),))
            self.relu2 = P.ReLU().shard(((1, 4),))

        def construct(self, data):
            x = self.relu1(data)
            x = self.reshape(x, (3, 4))
            x = self.relu2(x)
            return x

    size = 4
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 3]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_4():
    """Reshape a 2D tensor to 3D between ReLUs sharded as ((4, 1),) and ((1, 2, 2),)."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu1 = P.ReLU().shard(((4, 1),))
            self.relu2 = P.ReLU().shard(((1, 2, 2),))

        def construct(self, data):
            x = self.relu1(data)
            x = self.reshape(x, (3, 2, 2))
            x = self.relu2(x)
            return x

    size = 4
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 3]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_5():
    """Reshape a 3D tensor to 2D between ReLUs sharded as ((2, 2, 1),) and ((1, 4),)."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu1 = P.ReLU().shard(((2, 2, 1),))
            self.relu2 = P.ReLU().shard(((1, 4),))

        def construct(self, data):
            x = self.relu1(data)
            x = self.reshape(x, (3, 4))
            x = self.relu2(x)
            return x

    size = 4
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_6():
    """Expand a 2D tensor to 3D between ReLUs sharded as ((2, 1),) and ((1, 1, 4),)."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu1 = P.ReLU().shard(((2, 1),))
            self.relu2 = P.ReLU().shard(((1, 1, 4),))

        def construct(self, data):
            x = self.relu1(data)
            x = self.reshape(x, (1, 3, 4))
            x = self.relu2(x)
            return x

    size = 4
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 3]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_7():
    """Conv2d -> Softmax -> ReLU -> Mul -> Reshape pipeline compiled in auto_parallel mode."""
    class Net(nn.Cell):
        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
                     mul_size=(32, 1, 220, 220)):
            super().__init__()
            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
            self.mul = P.Mul()
            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                  kernel_size=5, has_bias=True, weight_init='ones',
                                  bias_init='ones', pad_mode='valid')
            self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
            self.softmax = nn.Softmax(axis=axis)
            self.relu = nn.ReLU()
            self.reshape = P.Reshape()
            self.input_shape = input_shape

        def construct(self, inputs):
            x = self.conv(inputs)
            x = self.softmax(x)
            x = self.relu(x)
            x = self.mul(x, self.mul_weight)
            x = self.reshape(x, self.input_shape)
            return x

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)

def test_reshape_unexpand_8():
    """Same net as test_reshape_unexpand_2, compiled in auto_parallel mode."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, data):
            x = self.reshape(self.mul_weight, (1, 128, 96))
            out = self.mul(x, self.mul_weight)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)