# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Graph-compilation tests for a sharded ``Squeeze`` followed by ``Mul``."""

import numpy as np

import mindspore as ms
from mindspore import context, Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell
from mindspore.ops import operations as P


class Net(Cell):
    """Squeeze ``x`` and multiply the result element-wise by ``b``.

    Args:
        strategy1: shard strategy passed to ``Squeeze.shard``; ``None`` leaves
            it unset (as in ``test_squeeze_auto_parallel``).
        strategy2: shard strategy passed to ``Mul.shard``.
        axis: axes forwarded to ``P.Squeeze``; defaults to ``()``.
    """

    def __init__(self, strategy1=None, strategy2=None, axis=()):
        super().__init__()
        self.squeeze = P.Squeeze(axis=axis).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)

    def construct(self, x, b):
        out = self.squeeze(x)
        out = self.mul(out, b)
        return out


# Shared inputs. _x has size-1 dimensions at axes 1 and 3; the tests expect
# squeezing them away to give a [64, 32] tensor that lines up with _b.
_x = Tensor(np.ones([64, 1, 32, 1]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile_net(net):
    """Compile ``net`` for training on the shared ``_x``/``_b`` inputs.

    The auto-parallel context configured by the calling test is always reset
    afterwards, even if compilation fails, so that one failing test cannot
    leak its parallel configuration into the tests that run after it.
    """
    try:
        net.set_auto_parallel()
        net.set_train()
        _cell_graph_executor.compile(net, _x, _b)
    finally:
        context.reset_auto_parallel_context()


def test_squeeze_data_parallel():
    """Split only the first (batch) dimension 16 ways for both operators."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1, 1),)
    strategy2 = ((16, 1), (16, 1))
    net = Net(strategy1, strategy2)
    compile_net(net)


def test_squeeze_model_parallel():
    """Split only the size-32 dimension 16 ways for both operators."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 16, 1),)
    strategy2 = ((1, 16), (1, 16))
    net = Net(strategy1, strategy2)
    compile_net(net)


def test_squeeze_specified_axis():
    """Squeeze explicit axes (1, 3) with differing Squeeze/Mul layouts.

    Squeeze uses a (4, 4) split on the surviving dimensions while Mul uses
    (8, 2), so the compiler has to reconcile the two layouts between ops.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 1, 4, 1),)
    strategy2 = ((8, 2), (8, 2))
    net = Net(strategy1, strategy2, (1, 3))
    compile_net(net)


def test_squeeze_auto_parallel():
    """Compile with no user strategies under ``auto_parallel`` mode."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
    net = Net()
    compile_net(net)


def test_squeeze_repeat_calc():
    """Shard Squeeze over only 8 of the 16 devices.

    The product of ``strategy1`` is 8, which is less than ``device_num``
    (16); presumably the remaining factor of 2 is handled as repeated
    calculation, which is what this test targets.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 8, 1),)
    strategy2 = ((2, 8), (2, 8))
    net = Net(strategy1, strategy2)
    compile_net(net)