• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test distribute predict """
16import numpy as np
17import pytest
18import mindspore.nn as nn
19from mindspore import Tensor, Model
20from mindspore.ops import operations as P
21from mindspore import context
22from mindspore.parallel._utils import _infer_rank_list
23
24
class Net(nn.Cell):
    """Attention-like toy network exercised by the distributed predict tests.

    The input is projected by three independent ``Dense(128, 768)`` layers
    (each with a ReLU activation). The first projection is multiplied by the
    transpose of the second, the ReLU of that product is multiplied by the
    third projection, and the ReLU of that result goes through a final
    ``Dense(768, 768)`` layer.
    """

    def __init__(self):
        super().__init__()
        # Attribute order is kept as-is in case parameter registration
        # order matters to the framework.
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        """Run the forward pass on a 2-D input ``x`` and return the fc4 output."""
        query = self.fc1(x)
        key = self.fc2(x)
        value = self.fc3(x)
        # Swap the two axes of the key projection so it can be right-multiplied.
        key_t = self.transpose(key, (1, 0))
        scores = self.relu4(self.matmul1(query, key_t))
        mixed = self.relu5(self.matmul2(scores, value))
        return self.fc4(mixed)
48
49
def test_distribute_predict():
    """Infer the predict layout and run prediction in semi-auto-parallel mode.

    Configures graph mode and a semi-auto-parallel context (8 devices, full
    batch, parallel optimizer enabled), builds ``Net`` wrapped in ``Model``,
    then calls ``infer_predict_layout`` and ``predict`` on a ``[32, 128]``
    float32 tensor of ones.

    Returns:
        tuple: ``(predict_map, output)`` as produced by
        ``model.infer_predict_layout`` and ``model.predict``.
    """
    context.set_context(mode=context.GRAPH_MODE)
    try:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
                                          enable_parallel_optimizer=True)
        inputs = Tensor(np.ones([32, 128]).astype(np.float32))
        net = Net()
        model = Model(net)
        predict_map = model.infer_predict_layout(inputs)
        output = model.predict(inputs)
    finally:
        # Restore the global auto-parallel configuration even when one of the
        # calls above raises, so a failure here cannot leak the
        # semi-auto-parallel settings into other tests.
        context.reset_auto_parallel_context()
    return predict_map, output
61
62
def test_edge_case():
    """Check that ``infer_predict_layout`` raises ``RuntimeError`` in two setups.

    The first call is made before this test sets any auto-parallel context.
    The second is made after only ``parallel_mode="semi_auto_parallel"`` has
    been set (no ``device_num`` or ``full_batch``). Both calls are expected to
    raise ``RuntimeError``.

    NOTE(review): the input is ``[32, 48]`` while ``Net`` starts with
    ``Dense(128, ...)`` layers; presumably the errors are raised before shapes
    matter -- confirm against ``Model.infer_predict_layout``.
    """
    context.set_context(mode=context.GRAPH_MODE)
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    net = Net()
    model = Model(net)
    try:
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
    finally:
        # The parallel mode set above is process-global; reset it so later
        # tests do not inherit semi-auto-parallel mode from this one.
        context.reset_auto_parallel_context()
73
74
75# standalone predict
def test_infer_rank_list1():
    """Standalone predict: no predict layout is given (``None``).

    Expects the ``weight`` entry to list ranks 0..7 with a ``False`` flag.
    """
    training_layout = {'weight': [[4, 8], [-1, 0]]}
    weight_entry = _infer_rank_list(training_layout, None)["weight"]
    assert list(weight_entry[0]) == list(range(8))
    assert weight_entry[1] is False
82
83
84# similar layout: gpt3 prediction mode
def test_infer_rank_list2():
    """Similar train/predict layouts (gpt3 prediction mode).

    Expects the ``weight`` entry to be ``([0], True)``.
    """
    inferred = _infer_rank_list(
        {'weight': [[4, 8], [-1, 0]]},
        {'weight': [[8], [-1, 0]]},
    )
    assert inferred == {'weight': ([0], True)}
91
92
93# same layout
def test_infer_rank_list3():
    """Identical train and predict layouts.

    Expects the ``weight`` entry to be ``([0], True)``.
    """
    # Two equal but separate dicts, as in the original test.
    training_layout = {'weight': [[4, 8], [-1, 0]]}
    inference_layout = {'weight': [[4, 8], [-1, 0]]}
    inferred = _infer_rank_list(training_layout, inference_layout)
    assert inferred == {'weight': ([0], True)}
100
101
102# totally different layout
def test_infer_rank_list4():
    """Totally different train and predict layouts.

    Expects the ``weight`` entry to list ranks 0..7 with a ``False`` flag.
    """
    weight_entry = _infer_rank_list(
        {'weight': [[4, 8], [-1, 0]]},
        {'weight': [[2, 2], [1, 0]]},
    )["weight"]
    assert list(weight_entry[0]) == list(range(8))
    assert weight_entry[1] is False
109
110
111# full shape ckpt
def test_infer_rank_list5():
    """Full-shape checkpoint: the train layout's tensor map is ``[-1, -1]``.

    Expects the ``weight`` entry to be ``([0], False)``.
    """
    inferred = _infer_rank_list(
        {'weight': [[8], [-1, -1]]},
        {'weight': [[2, 2], [1, 0]]},
    )
    assert inferred == {'weight': ([0], False)}
118