# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``functools.partial`` used inside a graph-mode ``construct``."""
from functools import partial

import numpy as np
import pytest

from mindspore import nn, Tensor, context

context.set_context(mode=context.GRAPH_MODE)


def _make_inputs():
    """Return float32 tensors ``(x, y, z)`` of shapes (3,), (3, 4), (3, 4, 5).

    The shapes are deliberately distinct so that an argument bound to the
    wrong parameter is caught by ``_assert_same``.
    """
    x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
    y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
    z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
    return x, y, z


def _assert_same(actual, expected):
    """Assert that two tensors have the same shape and element values."""
    # np.array_equal returns False on a shape mismatch, so a swapped
    # argument (different shape) fails here rather than passing silently.
    assert np.array_equal(actual.asnumpy(), expected.asnumpy())


def test_partial_pos_arg():
    """Bind ``x`` positionally via partial, then pass ``y`` and ``z`` positionally."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, x)
            ret = f(y, z)
            return ret

    x, y, z = _make_inputs()
    net = Net()
    out = net(x, y, z)
    _assert_same(out[0], x)
    _assert_same(out[1], y)
    _assert_same(out[2], z)


def test_partial_key_ward_arg():
    """Bind ``x`` by keyword via partial, then pass ``y`` and ``z`` by keyword."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, x=x)
            ret = f(y=y, z=z)
            return ret

    x, y, z = _make_inputs()
    net = Net()
    out = net(x, y, z)
    _assert_same(out[0], x)
    _assert_same(out[1], y)
    _assert_same(out[2], z)


def test_partial_key_ward_arg_update():
    """Re-supply a keyword (``y``) already bound by partial at call time."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, x=x, y=y)
            ret = f(y=y, z=z)
            return ret

    x, y, z = _make_inputs()
    net = Net()
    out = net(x, y, z)
    # NOTE: the same ``y`` is both bound and re-passed, so this only checks
    # that overriding a bound keyword is accepted; which value wins is
    # checked by test_partial_key_ward_arg_update_const.
    _assert_same(out[0], x)
    _assert_same(out[1], y)
    _assert_same(out[2], z)


def test_partial_key_ward_arg_and_pos_arg():
    """Bind ``y`` by keyword, then pass a positional literal and ``z`` by keyword."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self, x, y, z):
            f = partial(self.show, y=y)
            ret = f(2, z=z)
            return ret

    x, y, z = _make_inputs()
    net = Net()
    out = net(x, y, z)
    # out[0] is the literal 2 (construct ignores its ``x`` input); only the
    # tensor slots are compared here.
    _assert_same(out[1], y)
    _assert_same(out[2], z)


def test_partial_pos_arg_const():
    """Positional binding with constant arguments yields ``(1, 2, 3)``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, 1)
            ret = f(2, 3)
            return ret

    net = Net()
    assert net() == (1, 2, 3)


def test_partial_key_ward_arg_const():
    """Keyword binding with constant arguments yields ``(1, 2, 3)``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, x=1)
            ret = f(y=2, z=3)
            return ret

    net = Net()
    assert net() == (1, 2, 3)


def test_partial_key_ward_arg_update_const():
    """A keyword passed at call time (``y=3``) overrides the bound ``y=2``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, x=1, y=2)
            ret = f(y=3, z=4)
            return ret

    net = Net()
    assert net() == (1, 3, 4)


def test_partial_key_ward_arg_and_pos_arg_const():
    """Bound keyword ``y`` combines with a positional ``x`` and keyword ``z``."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, y=2)
            ret = f(1, z=3)
            return ret

    net = Net()
    assert net() == (1, 2, 3)


# In the three tests below, the expected text is matched verbatim against
# the raised TypeError's message, so its wording must stay exactly as the
# framework emits it.
def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
    """Passing ``x`` positionally after binding ``x=1`` raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, x=1)
            ret = f(1, 2, 3)
            return ret

    net = Net()
    with pytest.raises(TypeError) as ex:
        net()
    assert "Multiply values for specific argument: x" in str(ex.value)


def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
    """Passing ``y`` positionally after binding ``y=2`` raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, y=2)
            ret = f(1, 2, z=3)
            return ret

    net = Net()
    with pytest.raises(TypeError) as ex:
        net()
    assert "Multiply values for specific argument: y" in str(ex.value)


def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
    """Passing ``z`` positionally after binding ``z=1`` raises TypeError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def show(self, x, y, z):
            return x, y, z

        def construct(self):
            f = partial(self.show, z=1)
            ret = f(1, 2, 3)
            return ret

    net = Net()
    with pytest.raises(TypeError) as ex:
        net()
    assert "Multiply values for specific argument: z" in str(ex.value)