# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from mindspore import Tensor
from mindspore.parallel._tensor import _reshape_param_data


def _assert_tensor_equal(expected, actual, context=""):
    """Raise AssertionError unless ``actual`` holds the same values as ``expected``.

    Both tensors are converted with ``asnumpy()`` and compared with
    ``np.array_equal``, which checks the shape and every element. Comparing
    the printed form (``str(tensor)``) is fragile: numpy abbreviates large
    arrays with "..." when printing, so different tensors can print the same.

    Args:
        expected: Tensor holding the reference values.
        actual: Tensor produced by the code under test.
        context: Extra text (e.g. the case parameters) for the failure message.

    Raises:
        AssertionError: If the shapes or any element differ.
    """
    expected_np = expected.asnumpy()
    actual_np = actual.asnumpy()
    if not np.array_equal(expected_np, actual_np):
        raise AssertionError(
            "tensor mismatch ({}):\nexpected shape {}:\n{}\nactual shape {}:\n{}".format(
                context, expected_np.shape, expected_np, actual_np.shape, actual_np))


def test_reshape_param_data():
    """Check that ``_reshape_param_data`` reassembles per-device slices.

    In every case below, the input is each device's slice of the parameter
    concatenated along axis 0, in device order. The expected result is the
    full, un-sliced parameter.
    """
    matrix = [[1, 2, 3, 4], [5, 6, 7, 8]]
    replica = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

    # Each entry: (input data, dev_mat, tensor_map, expected data).
    cases = [
        # Both dims of the (2, 4) matrix are split: 4 devices x (1, 2) slices.
        ([[1, 2], [5, 6], [3, 4], [7, 8]], [2, 2], [0, 1], matrix),
        # Dim 0 is split and dim 1 (tensor_map -1) is not, so each full row
        # appears on two devices.
        ([[1, 2, 3, 4], [1, 2, 3, 4], [5, 6, 7, 8], [5, 6, 7, 8]],
         [2, 2], [1, -1], matrix),
        # No dim is split: each of the 4 devices holds the whole (2, 2, 2, 2)
        # tensor, giving 8 entries along axis 0.
        ([replica] * 8, [4], [-1, -1, -1, -1], [replica, replica]),
    ]

    for input_data, dev_mat, tensor_map, expected_data in cases:
        tensor = _reshape_param_data(Tensor(input_data), dev_mat, tensor_map)
        _assert_tensor_equal(
            Tensor(expected_data), tensor,
            "dev_mat={}, tensor_map={}".format(dev_mat, tensor_map))


if __name__ == '__main__':
    # Run the check directly when this file is executed as a script.
    test_reshape_param_data()