# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test view"""
import pytest

import mindspore as ms
import mindspore.nn as nn
import mindspore.common.initializer as init
from mindspore import Tensor
from mindspore import context

# All tests below compile `construct` in graph mode.
context.set_context(mode=context.GRAPH_MODE)


def test_view():
    """view(-1) on a 2x3 Tensor attribute flattens it to shape (6,)."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view(-1)

    net = Net()
    output = net()
    assert output.shape == (6,)


def test_view_initializer():
    """view(-1) also works on a tensor produced by `init.initializer`."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = init.initializer('normal', [2, 3], ms.float32)

        def construct(self):
            return self.value.view(-1)

    net = Net()
    output = net()
    assert output.shape == (6,)


def test_view_1():
    """view accepts the target shape as a single tuple argument."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view((3, 2))

    net = Net()
    output = net()
    assert output.shape == (3, 2)


def test_view_2():
    """view accepts the target shape as separate positional ints."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view(3, 2)

    net = Net()
    output = net()
    assert output.shape == (3, 2)


def test_view_parameter():
    """view(-1) on a Tensor passed in as a construct() input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view(-1)

    net = Net()
    output = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    assert output.shape == (6,)


def test_view_parameter_1():
    """Tuple-form view on a construct() input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view((3, 2))

    net = Net()
    output = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    assert output.shape == (3, 2)


def test_view_parameter_2():
    """Varargs-form view on a construct() input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return x.view(3, 2)

    net = Net()
    output = net(Tensor([[1, 2, 3], [4, 5, 6]]))
    assert output.shape == (3, 2)


def test_view_shape_error():
    """Calling view() with no shape must raise ValueError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view()

    net = Net()
    with pytest.raises(ValueError) as ex:
        net()
    # Checked after the `with` block: inside it, this line would never run
    # because net() raises first.
    assert "The shape variable should not be empty" in str(ex.value)


def test_view_shape_error_1():
    """Passing more than one tuple to view() must raise ValueError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6]])

        def construct(self):
            return self.value.view((2, 3), (4, 5))

    net = Net()
    with pytest.raises(ValueError) as ex:
        net()
    # Checked after the `with` block so the message assertion is reachable.
    assert "Only one tuple is needed" in str(ex.value)