# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from mindspore import Tensor
from mindspore.parallel._tensor import _reshape_param_data


def _assert_tensor_equal(expected, actual):
    """Assert that two tensors have the same shape and element values.

    The tensors are compared element-wise through their NumPy views rather
    than through their string representations. The check therefore does not
    depend on print formatting (NumPy summarises large arrays with "...").
    On mismatch, ``np.testing.assert_array_equal`` raises ``AssertionError``
    with a message describing the differing elements. Unlike a bare
    ``assert``, it is not stripped under ``python -O``.
    """
    np.testing.assert_array_equal(actual.asnumpy(), expected.asnumpy())


def test_reshape_param_data():
    """Check ``_reshape_param_data`` against hand-built expected tensors.

    Each case passes an input tensor, a device matrix (``dev_mat``) and a
    tensor map (``tensor_map``), then compares the result with the expected
    tensor.

    NOTE(review): from the fixtures, the input looks like per-device slices
    concatenated along axis 0, and the output looks like the reassembled
    full tensor. This is inferred from the data below; confirm against the
    implementation of ``_reshape_param_data``.
    """
    expected_tensor = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

    # Case 1: 2x2 device matrix, tensor_map [0, 1]; the (4, 2) input is
    # expected to come back as the (2, 4) tensor above.
    dev_mat = [2, 2]
    tensor_map = [0, 1]
    input_tensor = Tensor([[1, 2], [5, 6], [3, 4], [7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _assert_tensor_equal(expected_tensor, tensor)

    # Case 2: same 2x2 device matrix (restated explicitly), tensor_map
    # [1, -1]; each row of the expected tensor appears twice in the input.
    dev_mat = [2, 2]
    tensor_map = [1, -1]
    input_tensor = Tensor([[1, 2, 3, 4], [1, 2, 3, 4],
                           [5, 6, 7, 8], [5, 6, 7, 8]])
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _assert_tensor_equal(expected_tensor, tensor)

    # Case 3: 1-D device matrix of size 4 and tensor_map all -1. The input
    # is the expected (2, 2, 2, 2) tensor repeated four times along axis 0,
    # i.e. shape (8, 2, 2, 2).
    block = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    expected_tensor = Tensor([block] * 2)
    input_tensor = Tensor([block] * 8)
    dev_mat = [4]
    tensor_map = [-1, -1, -1, -1]
    tensor = _reshape_param_data(input_tensor, dev_mat, tensor_map)
    _assert_tensor_equal(expected_tensor, tensor)


if __name__ == '__main__':
    test_reshape_param_data()