• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17from mindspore import context
18import mindspore.nn as nn
19from mindspore.ops import operations as P
20from mindspore import Tensor, Parameter
21import mindspore as ms
22import mindspore.common.api as me
23from mindspore.common.initializer import initializer
24from mindspore.common import set_seed
25from hccl_test.manage.api import Hccl
26
class Net(nn.Cell):
    """Small network: a sharded MatMul against a trainable weight, then a sharded ReLU.

    Args:
        strategy1: Shard strategy handed to the MatMul operator.
        strategy2: Shard strategy handed to the ReLU operator.
        weight: Initial value wrapped into the parameter named "w1".
    """

    def __init__(self, strategy1, strategy2, weight):
        super().__init__()
        self.weight = Parameter(weight, "w1")
        # The second MatMul operand (the weight) is used transposed.
        matmul_op = P.MatMul(transpose_a=False, transpose_b=True)
        self.matmul = matmul_op.shard(strategy1)
        relu_op = P.ReLU()
        self.relu = relu_op.shard(strategy2)

    def construct(self, x):
        """Return ReLU(MatMul(x, weight))."""
        return self.relu(self.matmul(x, self.weight))
38
def check_initializer_weight_slice(init_name="Uniform"):
    """Compile the sharded net as ranks 0, 1 and 4 and compare their "w1" slices.

    The weight is created lazily through ``initializer(init_name, [64, 32])``
    and the net is compiled in semi_auto_parallel mode with 8 devices. After
    compilation each rank's "w1" parameter is expected to have shape (16, 32).

    Checks performed:
        * every collected slice has shape (16, 32);
        * ranks 0 and 4 hold identical data;
        * ranks 0 and 1 hold different data, except for the constant
          initializers "One" and "Zero", whose slices cannot differ.

    Args:
        init_name: Name of the initializer used to create the weight.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    def get_slice(rank):
        """Compile ``Net`` with the Hccl rank id set to ``rank``; return "w1" as a numpy array."""
        hccl = Hccl()
        rank_save = hccl.rank_id
        hccl.rank_id = rank
        try:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=8, global_rank=0)
            context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
            strategy1 = ((2, 1), (4, 1))
            strategy2 = ((2, 4),)
            context.set_context(mode=context.GRAPH_MODE)
            exe = me._cell_graph_executor

            x = Tensor(np.ones([32, 32]), dtype=ms.float32)
            weight = initializer(init_name, [64, 32], ms.float32)
            net = Net(strategy1, strategy2, weight)
            net.set_auto_parallel()
            net.set_train()
            exe.compile(net, x, auto_parallel_mode=True, phase='train')
        finally:
            # Always put the original rank id back, even when compilation
            # fails, so a failure here cannot leak a modified rank into
            # later tests (presumably the Hccl state is shared -- the
            # save/restore in the original code suggests so).
            hccl.rank_id = rank_save
        return net.parameters_dict()['w1'].data.asnumpy()

    slice0 = get_slice(0)
    slice1 = get_slice(1)
    slice4 = get_slice(4)

    expect_slice_shape = (16, 32)
    # Validate every slice's shape up front so the data comparisons below
    # never compare arrays of mismatched shapes.
    for rank_slice in (slice0, slice1, slice4):
        assert expect_slice_shape == rank_slice.shape

    assert np.array_equal(slice0, slice4)
    if init_name not in ["One", "Zero"]:
        assert not np.array_equal(slice0, slice1)
75
# Initializer names exercised by test_initializer_weight_slice.
initializers = [
    "Uniform",
    "Normal",
    "TruncatedNormal",
    "HeUniform",
    "HeNormal",
    "XavierUniform",
    "One",
    "Zero",
]
77
def test_initializer_weight_slice():
    """Run the weight-slice check once for every name listed in ``initializers``."""
    for name in initializers:
        check_initializer_weight_slice(name)
81
def test_wrong_order_set_parallel_mode_with_initializer():
    """Compiling must raise RuntimeError when semi_auto_parallel is set after the net is built.

    The weight comes from ``initializer`` and the net is constructed before
    the auto-parallel context is configured; that ordering is the point of
    the test, so the statement order below must be kept.
    """
    lazy_weight = initializer("Normal", [64, 32], ms.float32)
    matmul_strategy = ((2, 1), (4, 1))
    relu_strategy = ((2, 4),)
    net = Net(matmul_strategy, relu_strategy, lazy_weight)
    executor = me._cell_graph_executor
    inputs = Tensor(np.ones([32, 32]), dtype=ms.float32)

    # Parallel mode is configured only now, after the net already exists.
    parallel_cfg = dict(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    context.set_auto_parallel_context(**parallel_cfg)
    net.set_auto_parallel()
    with pytest.raises(RuntimeError):
        executor.compile(net, inputs, auto_parallel_mode=True, phase='train')
93
def test_wrong_order_set_same_parallel_mode_with_initializer():
    """Compile after switching from semi_auto_parallel to auto_parallel around net creation.

    semi_auto_parallel is configured before the initializer-backed weight and
    the net are created; the mode is then set to auto_parallel before
    compiling. The test passes as long as compilation does not raise.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    lazy_weight = initializer("Normal", [64, 32], ms.float32)
    matmul_strategy = ((2, 1), (4, 1))
    relu_strategy = ((2, 4),)
    net = Net(matmul_strategy, relu_strategy, lazy_weight)
    executor = me._cell_graph_executor
    inputs = Tensor(np.ones([32, 32]), dtype=ms.float32)

    # Second configuration happens after the net has been built.
    parallel_cfg = dict(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    context.set_auto_parallel_context(**parallel_cfg)
    net.set_auto_parallel()
    executor.compile(net, inputs, auto_parallel_mode=True, phase='train')
105
def test_wrong_order_set_parallel_mode_without_initializer():
    """Compile a net whose weight is a concrete Tensor, with the parallel mode set afterwards.

    Unlike the initializer-based variant, the weight here is built directly
    from a numpy array. The test passes as long as compilation does not raise.
    """
    dense_weight = Tensor(np.ones([64, 32]), ms.float32)
    matmul_strategy = ((2, 1), (4, 1))
    relu_strategy = ((2, 4),)
    net = Net(matmul_strategy, relu_strategy, dense_weight)
    executor = me._cell_graph_executor
    inputs = Tensor(np.ones([32, 32]), dtype=ms.float32)

    # Parallel mode is configured only now, after the net already exists.
    parallel_cfg = dict(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    context.set_auto_parallel_context(**parallel_cfg)
    net.set_auto_parallel()
    executor.compile(net, inputs, auto_parallel_mode=True, phase='train')
116
def test_check_initializer_weight_slice_seed(init_name="Uniform"):
    """Check that a fixed global seed makes the "w1" slices of ranks 0, 1 and 4 identical.

    ``set_seed(1)`` is called before every compile. The net is compiled in
    semi_auto_parallel mode with 8 devices from a weight created through
    ``initializer(init_name, [64, 32])``; each rank's "w1" parameter is
    expected to have shape (16, 32) and all three slices must hold the same
    data.

    Args:
        init_name: Name of the initializer used to create the weight.

    Raises:
        AssertionError: If a slice has the wrong shape or the slices differ.
    """
    def get_slice(rank):
        """Seed, compile ``Net`` with the Hccl rank id set to ``rank``; return "w1" as a numpy array."""
        set_seed(1)
        hccl = Hccl()
        rank_save = hccl.rank_id
        hccl.rank_id = rank
        try:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=8, global_rank=0)
            context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
            strategy1 = ((2, 1), (4, 1))
            strategy2 = ((2, 4),)
            context.set_context(mode=context.GRAPH_MODE)
            exe = me._cell_graph_executor

            x = Tensor(np.ones([32, 32]), dtype=ms.float32)
            weight = initializer(init_name, [64, 32], ms.float32)
            net = Net(strategy1, strategy2, weight)
            net.set_auto_parallel()
            net.set_train()
            exe.compile(net, x, auto_parallel_mode=True, phase='train')
        finally:
            # Always put the original rank id back, even when compilation
            # fails, so a failure here cannot leak a modified rank into
            # later tests (presumably the Hccl state is shared -- the
            # save/restore in the original code suggests so).
            hccl.rank_id = rank_save
        return net.parameters_dict()['w1'].data.asnumpy()

    slice0 = get_slice(0)
    slice1 = get_slice(1)
    slice4 = get_slice(4)

    expect_slice_shape = (16, 32)
    # Validate every slice's shape up front so the data comparisons below
    # never compare arrays of mismatched shapes.
    for rank_slice in (slice0, slice1, slice4):
        assert expect_slice_shape == rank_slice.shape

    assert np.array_equal(slice0, slice4)
    assert np.array_equal(slice0, slice1)
154
# Direct execution runs only the per-initializer slice test, without pytest.
if __name__ == "__main__":
    test_initializer_weight_slice()
157