1# Copyright 2019-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16import numpy as np 17import pytest 18 19import mindspore.context as context 20import mindspore.nn as nn 21from mindspore import Tensor 22from mindspore.ops import operations as P 23from mindspore.ops import composite as C 24 25class StridedSliceNet(nn.Cell): 26 def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0): 27 super(StridedSliceNet, self).__init__() 28 self.begin = begin 29 self.end = end 30 self.strides = stride 31 self.slice = P.StridedSlice(begin_mask, end_mask, ellipsis_mask) 32 33 def construct(self, x): 34 return self.slice(x, self.begin, self.end, self.strides) 35 36class GradData(nn.Cell): 37 def __init__(self, network): 38 super(GradData, self).__init__() 39 self.grad = C.GradOperation(get_all=True, sens_param=False) 40 self.network = network 41 42 def construct(self, x): 43 return self.grad(self.network)(x) 44 45 46def strided_slice_grad(nptype): 47 context.set_context(mode=context.GRAPH_MODE, device_target='GPU') 48 49 x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(nptype)) 50 net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1)) 51 dx = GradData(net)(x) 52 expect = np.array([[[[0., 0., 0., 0., 0.], 53 [0., 0., 0., 0., 0.], 54 [0., 0., 0., 0., 0.], 55 [0., 0., 0., 0., 0.]], 56 57 [[0., 0., 0., 0., 0.], 58 [0., 0., 0., 0., 0.], 59 [0., 0., 0., 0., 0.], 60 [0., 0., 0., 0., 0.]], 61 62 [[0., 0., 0., 0., 0.], 63 [0., 0., 0., 0., 0.], 64 [0., 0., 0., 0., 0.], 65 [0., 0., 0., 0., 0.]]], 66 67 68 [[[0., 0., 1., 1., 0.], 69 [0., 0., 1., 1., 0.], 70 [0., 0., 0., 0., 0.], 71 [0., 0., 0., 0., 0.]], 72 73 [[0., 0., 1., 1., 0.], 74 [0., 0., 1., 1., 0.], 75 [0., 0., 0., 0., 0.], 76 [0., 0., 0., 0., 0.]], 77 78 [[0., 0., 0., 0., 0.], 79 [0., 0., 0., 0., 0.], 80 [0., 0., 0., 0., 0.], 81 [0., 0., 0., 0., 0.]]]]).astype(nptype) 82 assert np.allclose(dx[0].asnumpy(), expect) 83 84 net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2)) 85 dx = GradData(net)(x) 86 expect = np.array([[[[0., 0., 0., 0., 0.], 87 [0., 0., 0., 0., 0.], 88 [0., 0., 0., 0., 0.], 89 [0., 0., 0., 0., 0.]], 90 91 [[0., 0., 0., 0., 0.], 92 [0., 0., 0., 0., 0.], 93 [0., 0., 0., 0., 0.], 94 [0., 0., 0., 0., 0.]], 95 96 [[0., 0., 0., 0., 0.], 97 [0., 0., 0., 0., 0.], 98 [0., 0., 0., 0., 0.], 99 [0., 0., 0., 0., 0.]]], 100 101 102 [[[0., 0., 1., 0., 1.], 103 [0., 0., 1., 0., 1.], 104 [0., 0., 0., 0., 0.], 105 [0., 0., 0., 0., 0.]], 106 107 [[0., 0., 1., 0., 1.], 108 [0., 0., 1., 0., 1.], 109 [0., 0., 0., 0., 0.], 110 [0., 0., 0., 0., 0.]], 111 112 [[0., 0., 0., 0., 0.], 113 [0., 0., 0., 0., 0.], 114 [0., 0., 0., 0., 0.], 115 [0., 0., 0., 0., 0.]]]]).astype(nptype) 116 assert np.allclose(dx[0].asnumpy(), expect) 117 118 119 net = StridedSliceNet((1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1)) 120 dx = GradData(net)(x) 121 expect = np.array([[[[0., 0., 0., 0., 0.], 122 [0., 0., 0., 0., 0.], 123 [0., 0., 0., 0., 0.], 124 [0., 0., 0., 0., 0.]], 125 126 [[0., 0., 0., 0., 0.], 127 [0., 0., 0., 0., 0.], 128 [0., 0., 0., 0., 0.], 129 [0., 0., 0., 0., 0.]], 130 131 [[0., 0., 0., 0., 0.], 132 [0., 0., 0., 0., 0.], 133 [0., 0., 0., 0., 0.], 134 [0., 0., 0., 0., 0.]]], 135 136 137 [[[0., 0., 1., 1., 1.], 138 [0., 0., 1., 1., 1.], 139 [0., 0., 0., 0., 0.], 140 [0., 0., 0., 0., 0.]], 141 142 [[0., 0., 1., 1., 1.], 143 [0., 0., 1., 1., 1.], 144 [0., 0., 0., 0., 0.], 145 [0., 0., 0., 0., 0.]], 146 147 [[0., 0., 0., 0., 0.], 148 [0., 0., 0., 0., 0.], 149 [0., 0., 0., 0., 0.], 150 [0., 0., 0., 0., 0.]]]]).astype(nptype) 151 assert np.allclose(dx[0].asnumpy(), expect) 152 153 154 net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1), 155 begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100) 156 dx = GradData(net)(x) 157 expect = np.array([[[[0., 0., 0., 0., 0.], 158 [0., 0., 0., 0., 0.], 159 [0., 0., 0., 0., 0.], 160 [0., 0., 0., 0., 0.]], 161 162 [[0., 0., 0., 0., 0.], 163 [0., 0., 0., 0., 0.], 164 [0., 0., 0., 0., 0.], 165 [0., 0., 0., 0., 0.]], 166 167 [[0., 0., 0., 0., 0.], 168 [0., 0., 0., 0., 0.], 169 [0., 0., 0., 0., 0.], 170 [0., 0., 0., 0., 0.]]], 171 172 173 [[[1., 1., 1., 1., 0.], 174 [1., 1., 1., 1., 0.], 175 [1., 1., 1., 1., 0.], 176 [1., 1., 1., 1., 0.]], 177 178 [[1., 1., 1., 1., 0.], 179 [1., 1., 1., 1., 0.], 180 [1., 1., 1., 1., 0.], 181 [1., 1., 1., 1., 0.]], 182 183 [[1., 1., 1., 1., 0.], 184 [1., 1., 1., 1., 0.], 185 [1., 1., 1., 1., 0.], 186 [1., 1., 1., 1., 0.]]]]).astype(nptype) 187 assert np.allclose(dx[0].asnumpy(), expect) 188 189 x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32)) 190 net = StridedSliceNet((1, 0, 0), (2, -3, 3), (1, 1, 3)) 191 dx = GradData(net)(x) 192 expect = np.array([[[0., 0., 0., 0., 0.], 193 [0., 0., 0., 0., 0.], 194 [0., 0., 0., 0., 0.], 195 [0., 0., 0., 0., 0.]], 196 197 [[1., 0., 0., 0., 0.], 198 [0., 0., 0., 0., 0.], 199 [0., 0., 0., 0., 0.], 200 [0., 0., 0., 0., 0.]], 201 202 [[0., 0., 0., 0., 0.], 203 [0., 0., 0., 0., 0.], 204 [0., 0., 0., 0., 0.], 205 [0., 0., 0., 0., 0.]]]).astype(nptype) 206 assert np.allclose(dx[0].asnumpy(), expect) 207 208 x = Tensor(np.arange(0, 1 * 1 * 1 * 2 * 3 * 4 * 5).reshape(1, 1, 1, 2, 3, 4, 5).astype(nptype)) 209 net = StridedSliceNet((0, 0, 0, 1, 1, 2, 2), (1, 1, 1, 2, 3, 3, 4), (1, 1, 1, 1, 1, 1, 1)) 210 dx = GradData(net)(x) 211 expect = np.array([[[[[[[0., 0., 0., 0., 0.], 212 [0., 0., 0., 0., 0.], 213 [0., 0., 0., 0., 0.], 214 [0., 0., 0., 0., 0.]], 215 216 [[0., 0., 0., 0., 0.], 217 [0., 0., 0., 0., 0.], 218 [0., 0., 0., 0., 0.], 219 [0., 0., 0., 0., 0.]], 220 221 [[0., 0., 0., 0., 0.], 222 [0., 0., 0., 0., 0.], 223 [0., 0., 0., 0., 0.], 224 [0., 0., 0., 0., 0.]]], 225 226 [[[0., 0., 0., 0., 0.], 227 [0., 0., 0., 0., 0.], 228 [0., 0., 0., 0., 0.], 229 [0., 0., 0., 0., 0.]], 230 231 [[0., 0., 0., 0., 0.], 232 [0., 0., 0., 0., 0.], 233 [0., 0., 1., 1., 0.], 234 [0., 0., 0., 0., 0.]], 235 236 [[0., 0., 0., 0., 0.], 237 [0., 0., 0., 0., 0.], 238 [0., 0., 1., 1., 0.], 239 [0., 0., 0., 0., 0.]]]]]]]).astype(nptype) 240 assert np.allclose(dx[0].asnumpy(), expect) 241 242@pytest.mark.level0 243@pytest.mark.platform_x86_gpu_training 244@pytest.mark.env_onecard 245def test_strided_slice_grad_float64(): 246 strided_slice_grad(np.float64) 247 248@pytest.mark.level0 249@pytest.mark.platform_x86_gpu_training 250@pytest.mark.env_onecard 251def test_strided_slice_grad_float32(): 252 strided_slice_grad(np.float32) 253 254@pytest.mark.level0 255@pytest.mark.platform_x86_gpu_training 256@pytest.mark.env_onecard 257def test_strided_slice_grad_float16(): 258 strided_slice_grad(np.float16) 259 260@pytest.mark.level0 261@pytest.mark.platform_x86_gpu_training 262@pytest.mark.env_onecard 263def test_strided_slice_grad_int64(): 264 strided_slice_grad(np.int64) 265 266@pytest.mark.level0 267@pytest.mark.platform_x86_gpu_training 268@pytest.mark.env_onecard 269def test_strided_slice_grad_int32(): 270 strided_slice_grad(np.int32) 271 272@pytest.mark.level0 273@pytest.mark.platform_x86_gpu_training 274@pytest.mark.env_onecard 275def test_strided_slice_grad_int16(): 276 strided_slice_grad(np.int16) 277 278@pytest.mark.level0 279@pytest.mark.platform_x86_gpu_training 280@pytest.mark.env_onecard 281def test_strided_slice_grad_int8(): 282 strided_slice_grad(np.int8) 283 284@pytest.mark.level0 285@pytest.mark.platform_x86_gpu_training 286@pytest.mark.env_onecard 287def test_strided_slice_grad_uint64(): 288 strided_slice_grad(np.uint64) 289 290@pytest.mark.level0 291@pytest.mark.platform_x86_gpu_training 292@pytest.mark.env_onecard 293def test_strided_slice_grad_uint32(): 294 strided_slice_grad(np.uint32) 295 296@pytest.mark.level0 297@pytest.mark.platform_x86_gpu_training 298@pytest.mark.env_onecard 299def test_strided_slice_grad_uint16(): 300 strided_slice_grad(np.uint16) 301 302@pytest.mark.level0 303@pytest.mark.platform_x86_gpu_training 304@pytest.mark.env_onecard 305def test_strided_slice_grad_uint8(): 306 strided_slice_grad(np.uint8) 307 308@pytest.mark.level0 309@pytest.mark.platform_x86_gpu_training 310@pytest.mark.env_onecard 311def test_strided_slice_grad_bool(): 312 strided_slice_grad(np.bool) 313