# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for how initializer-backed weights are sliced under semi-auto parallel.

A small MatMul + ReLU network is compiled for an 8-device layout while the
mocked ``Hccl`` communicator reports a chosen rank id; the tests then inspect
the ``w1`` data that rank ends up holding.
"""

import numpy as np
import pytest
from mindspore import context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
import mindspore as ms
import mindspore.common.api as me
from mindspore.common.initializer import initializer
from mindspore.common import set_seed
from hccl_test.manage.api import Hccl

# Per-rank shape expected for the [64, 32] ``w1`` weight after sharding.
_EXPECT_SLICE_SHAPE = (16, 32)


class Net(nn.Cell):
    """MatMul against a transposed weight, followed by ReLU, each explicitly sharded.

    Args:
        strategy1: Shard strategy applied to the MatMul.
        strategy2: Shard strategy applied to the ReLU.
        weight: Initial value of the ``w1`` parameter (an initializer or a Tensor).
    """

    def __init__(self, strategy1, strategy2, weight):
        super().__init__()
        self.weight = Parameter(weight, "w1")
        self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
        self.relu = P.ReLU().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.weight)
        out = self.relu(out)
        return out


def _get_weight_slice(rank, init_name, seed=None):
    """Compile ``Net`` while the mocked communicator reports ``rank``; return ``w1``.

    Args:
        rank: Rank id written into the mocked ``Hccl`` object for the duration
            of the compilation.
        init_name: Initializer name passed to ``initializer`` for the weight.
        seed: When not None, ``set_seed(seed)`` is called first so every call
            starts from the same global seed.

    Returns:
        numpy.ndarray with the ``w1`` parameter data after compilation.
    """
    if seed is not None:
        set_seed(seed)
    hccl = Hccl()
    rank_save = hccl.rank_id
    hccl.rank_id = rank
    try:
        context.reset_auto_parallel_context()
        # NOTE(review): global_rank stays 0 here; the effective rank appears to
        # come from the mocked Hccl.rank_id set above - confirm.
        context.set_auto_parallel_context(device_num=8, global_rank=0)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        strategy1 = ((2, 1), (4, 1))
        strategy2 = ((2, 4),)
        context.set_context(mode=context.GRAPH_MODE)
        exe = me._cell_graph_executor

        x = Tensor(np.ones([32, 32]), dtype=ms.float32)
        weight = initializer(init_name, [64, 32], ms.float32)
        net = Net(strategy1, strategy2, weight)
        net.set_auto_parallel()
        net.set_train()
        exe.compile(net, x, auto_parallel_mode=True, phase='train')
    finally:
        # Restore the mocked rank even when compilation fails, so later tests
        # do not silently run under the wrong rank id.
        hccl.rank_id = rank_save
    return net.parameters_dict()['w1'].data.asnumpy()


def check_initializer_weight_slice(init_name="Uniform"):
    """Check the ``w1`` slices of ranks 0, 1 and 4 for one initializer.

    Every slice must have shape ``_EXPECT_SLICE_SHAPE``. Ranks 0 and 4 must
    hold identical data (presumably the same shard of the weight under this
    layout). Ranks 0 and 1 must differ, except for "One" and "Zero", which are
    skipped because they presumably fill every shard with the same constant.
    """
    slice0 = _get_weight_slice(0, init_name)
    slice1 = _get_weight_slice(1, init_name)
    slice4 = _get_weight_slice(4, init_name)

    assert slice0.shape == _EXPECT_SLICE_SHAPE
    assert np.array_equal(slice0, slice4)
    if init_name not in ["One", "Zero"]:
        assert not np.array_equal(slice0, slice1)


initializers = ["Uniform", "Normal", "TruncatedNormal", "HeUniform", "HeNormal", "XavierUniform", "One", "Zero"]


def test_initializer_weight_slice():
    """Run the slice check for every initializer name in ``initializers``."""
    for init_name in initializers:
        check_initializer_weight_slice(init_name)


def test_wrong_order_set_parallel_mode_with_initializer():
    """Entering a parallel mode after creating an initializer-backed weight must fail to compile."""
    # Start outside any parallel mode regardless of what earlier tests left
    # behind; otherwise the "wrong order" scenario is not actually exercised.
    context.reset_auto_parallel_context()
    weight = initializer("Normal", [64, 32], ms.float32)
    strategy1 = ((2, 1), (4, 1))
    strategy2 = ((2, 4),)
    net = Net(strategy1, strategy2, weight)
    exe = me._cell_graph_executor
    x = Tensor(np.ones([32, 32]), dtype=ms.float32)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    net.set_auto_parallel()
    with pytest.raises(RuntimeError):
        exe.compile(net, x, auto_parallel_mode=True, phase='train')


def test_wrong_order_set_same_parallel_mode_with_initializer():
    """Switching semi_auto_parallel -> auto_parallel after creating the weight still compiles."""
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    weight = initializer("Normal", [64, 32], ms.float32)
    strategy1 = ((2, 1), (4, 1))
    strategy2 = ((2, 4),)
    net = Net(strategy1, strategy2, weight)
    exe = me._cell_graph_executor
    x = Tensor(np.ones([32, 32]), dtype=ms.float32)
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net.set_auto_parallel()
    exe.compile(net, x, auto_parallel_mode=True, phase='train')


def test_wrong_order_set_parallel_mode_without_initializer():
    """Setting the parallel mode late is fine when the weight is a plain Tensor."""
    context.reset_auto_parallel_context()
    weight = Tensor(np.ones([64, 32]), ms.float32)
    strategy1 = ((2, 1), (4, 1))
    strategy2 = ((2, 4),)
    net = Net(strategy1, strategy2, weight)
    exe = me._cell_graph_executor
    x = Tensor(np.ones([32, 32]), dtype=ms.float32)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    net.set_auto_parallel()
    exe.compile(net, x, auto_parallel_mode=True, phase='train')


def test_check_initializer_weight_slice_seed(init_name="Uniform"):
    """With the global seed fixed to 1 before each compile, ranks 0, 1 and 4 get identical slices."""
    slice0 = _get_weight_slice(0, init_name, seed=1)
    slice1 = _get_weight_slice(1, init_name, seed=1)
    slice4 = _get_weight_slice(4, init_name, seed=1)

    assert slice0.shape == _EXPECT_SLICE_SHAPE
    assert np.array_equal(slice0, slice4)
    assert np.array_equal(slice0, slice1)


if __name__ == '__main__':
    test_initializer_weight_slice()