# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test distribute predict """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore.ops import operations as P
from mindspore import context
from mindspore.parallel._utils import _infer_rank_list


class Net(nn.Cell):
    """Small network used to exercise distributed predict-layout inference.

    Three Dense projections (q, k, v) of the same input feed a
    MatMul/ReLU chain that resembles a simplified attention pattern
    (ReLU in place of softmax), followed by a final Dense layer.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        """Return fc4(relu(relu(q @ k^T) @ v)) where q, k, v = fc1/fc2/fc3(x)."""
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


def test_distribute_predict():
    """Infer the predict layout and run predict under 8-device semi-auto parallel.

    Returns the layout map produced by ``Model.infer_predict_layout`` and the
    output of ``Model.predict`` for the same input.
    """
    context.set_context(mode=context.GRAPH_MODE)
    try:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True,
                                          enable_parallel_optimizer=True)
        inputs = Tensor(np.ones([32, 128]).astype(np.float32))
        net = Net()
        model = Model(net)
        predict_map = model.infer_predict_layout(inputs)
        output = model.predict(inputs)
    finally:
        # The auto-parallel context is process-global: restore it even when
        # layout inference or predict fails, so later tests do not inherit
        # the 8-device semi-auto configuration.
        context.reset_auto_parallel_context()
    return predict_map, output


def test_edge_case():
    """Check that infer_predict_layout raises RuntimeError in two error cases."""
    context.set_context(mode=context.GRAPH_MODE)
    # Start from the default auto-parallel context so the first check does not
    # depend on state left behind by another test.
    context.reset_auto_parallel_context()
    # 48 input features, while the first Dense layers of Net are built with 128.
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    net = Net()
    model = Model(net)
    try:
        # Default (non-semi-auto) parallel mode: expected to raise.
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        # Semi-auto parallel mode; presumably fails on the 48 vs 128 feature
        # mismatch — expected to raise.
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
    finally:
        # Do not leak the semi_auto_parallel setting into subsequent tests.
        context.reset_auto_parallel_context()


# standalone predict
def test_infer_rank_list1():
    """With no predict layout, all ranks 0..7 are listed and the flag is False."""
    train_map = {'weight': [[4, 8], [-1, 0]]}
    predict_map = None
    rank_list = _infer_rank_list(train_map, predict_map)["weight"]
    assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
    assert rank_list[1] is False


# similar layout: gpt3 prediction mode
def test_infer_rank_list2():
    """A similar predict layout maps 'weight' to ([0], True)."""
    train_map = {'weight': [[4, 8], [-1, 0]]}
    predict_map = {'weight': [[8], [-1, 0]]}
    rank_list = _infer_rank_list(train_map, predict_map)
    expect_map = {'weight': ([0], True)}
    assert rank_list == expect_map


# same layout
def test_infer_rank_list3():
    """An identical train/predict layout maps 'weight' to ([0], True)."""
    train_map = {'weight': [[4, 8], [-1, 0]]}
    predict_map = {'weight': [[4, 8], [-1, 0]]}
    rank_list = _infer_rank_list(train_map, predict_map)
    expect_map = {'weight': ([0], True)}
    assert rank_list == expect_map


# totally different layout
def test_infer_rank_list4():
    """A totally different predict layout lists all ranks 0..7 with flag False."""
    train_map = {'weight': [[4, 8], [-1, 0]]}
    predict_map = {'weight': [[2, 2], [1, 0]]}
    rank_list = _infer_rank_list(train_map, predict_map)["weight"]
    assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
    assert rank_list[1] is False


# full shape ckpt
def test_infer_rank_list5():
    """A full-shape (unsplit) training layout maps 'weight' to ([0], False)."""
    train_map = {'weight': [[8], [-1, -1]]}
    predict_map = {'weight': [[2, 2], [1, 0]]}
    rank_list = _infer_rank_list(train_map, predict_map)
    expect_map = {'weight': ([0], False)}
    assert rank_list == expect_map