1# Copyright 2019-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16import numpy as np 17import pytest 18 19import mindspore.context as context 20import mindspore.nn as nn 21from mindspore import Tensor 22from mindspore.ops.operations import _inner_ops as inner 23from mindspore.ops import operations as P 24 25 26class GatherNet(nn.Cell): 27 def __init__(self): 28 super(GatherNet, self).__init__() 29 self.gather = P.Gather() 30 31 def construct(self, x, indices): 32 return self.gather(x, indices, 1) 33 34 35@pytest.mark.level0 36@pytest.mark.platform_x86_gpu_training 37@pytest.mark.env_onecard 38def test_gather0(): 39 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) 40 indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4')) 41 expect = np.array([[[[[[[20., 21., 22., 23., 24.], 42 [25., 26., 27., 28., 29.], 43 [30., 31., 32., 33., 34.], 44 [35., 36., 37., 38., 39.]], 45 46 [[20., 21., 22., 23., 24.], 47 [25., 26., 27., 28., 29.], 48 [30., 31., 32., 33., 34.], 49 [35., 36., 37., 38., 39.]], 50 51 [[20., 21., 22., 23., 24.], 52 [25., 26., 27., 28., 29.], 53 [30., 31., 32., 33., 34.], 54 [35., 36., 37., 38., 39.]], 55 56 [[20., 21., 22., 23., 24.], 57 [25., 26., 27., 28., 29.], 58 [30., 31., 32., 33., 34.], 59 [35., 36., 37., 38., 39.]], 60 61 [[20., 21., 22., 23., 24.], 62 [25., 26., 27., 28., 29.], 63 [30., 31., 32., 33., 34.], 64 [35., 36., 37., 38., 39.]]], 65 66 [[[20., 21., 22., 23., 24.], 67 [25., 26., 27., 28., 29.], 68 [30., 31., 32., 33., 34.], 69 [35., 36., 37., 38., 39.]], 70 71 [[20., 21., 22., 23., 24.], 72 [25., 26., 27., 28., 29.], 73 [30., 31., 32., 33., 34.], 74 [35., 36., 37., 38., 39.]], 75 76 [[20., 21., 22., 23., 24.], 77 [25., 26., 27., 28., 29.], 78 [30., 31., 32., 33., 34.], 79 [35., 36., 37., 38., 39.]], 80 81 [[20., 21., 22., 23., 24.], 82 [25., 26., 27., 28., 29.], 83 [30., 31., 32., 33., 34.], 84 [35., 36., 37., 38., 39.]], 85 86 [[20., 21., 22., 23., 24.], 87 [25., 26., 27., 28., 29.], 88 [30., 31., 32., 33., 34.], 89 [35., 36., 37., 38., 39.]]], 90 91 [[[20., 21., 22., 23., 24.], 92 [25., 26., 27., 28., 29.], 93 [30., 31., 32., 33., 34.], 94 [35., 36., 37., 38., 39.]], 95 96 [[20., 21., 22., 23., 24.], 97 [25., 26., 27., 28., 29.], 98 [30., 31., 32., 33., 34.], 99 [35., 36., 37., 38., 39.]], 100 101 [[20., 21., 22., 23., 24.], 102 [25., 26., 27., 28., 29.], 103 [30., 31., 32., 33., 34.], 104 [35., 36., 37., 38., 39.]], 105 106 [[20., 21., 22., 23., 24.], 107 [25., 26., 27., 28., 29.], 108 [30., 31., 32., 33., 34.], 109 [35., 36., 37., 38., 39.]], 110 111 [[20., 21., 22., 23., 24.], 112 [25., 26., 27., 28., 29.], 113 [30., 31., 32., 33., 34.], 114 [35., 36., 37., 38., 39.]]], 115 116 [[[20., 21., 22., 23., 24.], 117 [25., 26., 27., 28., 29.], 118 [30., 31., 32., 33., 34.], 119 [35., 36., 37., 38., 39.]], 120 121 [[20., 21., 22., 23., 24.], 122 [25., 26., 27., 28., 29.], 123 [30., 31., 32., 33., 34.], 124 [35., 36., 37., 38., 39.]], 125 126 [[20., 21., 22., 23., 24.], 127 [25., 26., 27., 28., 29.], 128 [30., 31., 32., 33., 34.], 129 [35., 36., 37., 38., 39.]], 130 131 [[20., 21., 22., 23., 24.], 132 [25., 26., 27., 28., 29.], 133 [30., 31., 32., 33., 34.], 134 [35., 36., 37., 38., 39.]], 135 136 [[20., 21., 22., 23., 24.], 137 [25., 26., 27., 28., 29.], 138 [30., 31., 32., 33., 34.], 139 [35., 36., 37., 38., 39.]]]], 140 141 [[[[20., 21., 22., 23., 24.], 142 [25., 26., 27., 28., 29.], 143 [30., 31., 32., 33., 34.], 144 [35., 36., 37., 38., 39.]], 145 146 [[20., 21., 22., 23., 24.], 147 [25., 26., 27., 28., 29.], 148 [30., 31., 32., 33., 34.], 149 [35., 36., 37., 38., 39.]], 150 151 [[20., 21., 22., 23., 24.], 152 [25., 26., 27., 28., 29.], 153 [30., 31., 32., 33., 34.], 154 [35., 36., 37., 38., 39.]], 155 156 [[20., 21., 22., 23., 24.], 157 [25., 26., 27., 28., 29.], 158 [30., 31., 32., 33., 34.], 159 [35., 36., 37., 38., 39.]], 160 161 [[20., 21., 22., 23., 24.], 162 [25., 26., 27., 28., 29.], 163 [30., 31., 32., 33., 34.], 164 [35., 36., 37., 38., 39.]]], 165 166 [[[20., 21., 22., 23., 24.], 167 [25., 26., 27., 28., 29.], 168 [30., 31., 32., 33., 34.], 169 [35., 36., 37., 38., 39.]], 170 171 [[20., 21., 22., 23., 24.], 172 [25., 26., 27., 28., 29.], 173 [30., 31., 32., 33., 34.], 174 [35., 36., 37., 38., 39.]], 175 176 [[20., 21., 22., 23., 24.], 177 [25., 26., 27., 28., 29.], 178 [30., 31., 32., 33., 34.], 179 [35., 36., 37., 38., 39.]], 180 181 [[20., 21., 22., 23., 24.], 182 [25., 26., 27., 28., 29.], 183 [30., 31., 32., 33., 34.], 184 [35., 36., 37., 38., 39.]], 185 186 [[20., 21., 22., 23., 24.], 187 [25., 26., 27., 28., 29.], 188 [30., 31., 32., 33., 34.], 189 [35., 36., 37., 38., 39.]]], 190 191 [[[20., 21., 22., 23., 24.], 192 [25., 26., 27., 28., 29.], 193 [30., 31., 32., 33., 34.], 194 [35., 36., 37., 38., 39.]], 195 196 [[20., 21., 22., 23., 24.], 197 [25., 26., 27., 28., 29.], 198 [30., 31., 32., 33., 34.], 199 [35., 36., 37., 38., 39.]], 200 201 [[20., 21., 22., 23., 24.], 202 [25., 26., 27., 28., 29.], 203 [30., 31., 32., 33., 34.], 204 [35., 36., 37., 38., 39.]], 205 206 [[20., 21., 22., 23., 24.], 207 [25., 26., 27., 28., 29.], 208 [30., 31., 32., 33., 34.], 209 [35., 36., 37., 38., 39.]], 210 211 [[20., 21., 22., 23., 24.], 212 [25., 26., 27., 28., 29.], 213 [30., 31., 32., 33., 34.], 214 [35., 36., 37., 38., 39.]]], 215 216 [[[20., 21., 22., 23., 24.], 217 [25., 26., 27., 28., 29.], 218 [30., 31., 32., 33., 34.], 219 [35., 36., 37., 38., 39.]], 220 221 [[20., 21., 22., 23., 24.], 222 [25., 26., 27., 28., 29.], 223 [30., 31., 32., 33., 34.], 224 [35., 36., 37., 38., 39.]], 225 226 [[20., 21., 22., 23., 24.], 227 [25., 26., 27., 28., 29.], 228 [30., 31., 32., 33., 34.], 229 [35., 36., 37., 38., 39.]], 230 231 [[20., 21., 22., 23., 24.], 232 [25., 26., 27., 28., 29.], 233 [30., 31., 32., 33., 34.], 234 [35., 36., 37., 38., 39.]], 235 236 [[20., 21., 22., 23., 24.], 237 [25., 26., 27., 28., 29.], 238 [30., 31., 32., 33., 34.], 239 [35., 36., 37., 38., 39.]]]]], 240 241 [[[[[20., 21., 22., 23., 24.], 242 [25., 26., 27., 28., 29.], 243 [30., 31., 32., 33., 34.], 244 [35., 36., 37., 38., 39.]], 245 246 [[20., 21., 22., 23., 24.], 247 [25., 26., 27., 28., 29.], 248 [30., 31., 32., 33., 34.], 249 [35., 36., 37., 38., 39.]], 250 251 [[20., 21., 22., 23., 24.], 252 [25., 26., 27., 28., 29.], 253 [30., 31., 32., 33., 34.], 254 [35., 36., 37., 38., 39.]], 255 256 [[20., 21., 22., 23., 24.], 257 [25., 26., 27., 28., 29.], 258 [30., 31., 32., 33., 34.], 259 [35., 36., 37., 38., 39.]], 260 261 [[20., 21., 22., 23., 24.], 262 [25., 26., 27., 28., 29.], 263 [30., 31., 32., 33., 34.], 264 [35., 36., 37., 38., 39.]]], 265 266 [[[20., 21., 22., 23., 24.], 267 [25., 26., 27., 28., 29.], 268 [30., 31., 32., 33., 34.], 269 [35., 36., 37., 38., 39.]], 270 271 [[20., 21., 22., 23., 24.], 272 [25., 26., 27., 28., 29.], 273 [30., 31., 32., 33., 34.], 274 [35., 36., 37., 38., 39.]], 275 276 [[20., 21., 22., 23., 24.], 277 [25., 26., 27., 28., 29.], 278 [30., 31., 32., 33., 34.], 279 [35., 36., 37., 38., 39.]], 280 281 [[20., 21., 22., 23., 24.], 282 [25., 26., 27., 28., 29.], 283 [30., 31., 32., 33., 34.], 284 [35., 36., 37., 38., 39.]], 285 286 [[20., 21., 22., 23., 24.], 287 [25., 26., 27., 28., 29.], 288 [30., 31., 32., 33., 34.], 289 [35., 36., 37., 38., 39.]]], 290 291 [[[20., 21., 22., 23., 24.], 292 [25., 26., 27., 28., 29.], 293 [30., 31., 32., 33., 34.], 294 [35., 36., 37., 38., 39.]], 295 296 [[20., 21., 22., 23., 24.], 297 [25., 26., 27., 28., 29.], 298 [30., 31., 32., 33., 34.], 299 [35., 36., 37., 38., 39.]], 300 301 [[20., 21., 22., 23., 24.], 302 [25., 26., 27., 28., 29.], 303 [30., 31., 32., 33., 34.], 304 [35., 36., 37., 38., 39.]], 305 306 [[20., 21., 22., 23., 24.], 307 [25., 26., 27., 28., 29.], 308 [30., 31., 32., 33., 34.], 309 [35., 36., 37., 38., 39.]], 310 311 [[20., 21., 22., 23., 24.], 312 [25., 26., 27., 28., 29.], 313 [30., 31., 32., 33., 34.], 314 [35., 36., 37., 38., 39.]]], 315 316 [[[20., 21., 22., 23., 24.], 317 [25., 26., 27., 28., 29.], 318 [30., 31., 32., 33., 34.], 319 [35., 36., 37., 38., 39.]], 320 321 [[20., 21., 22., 23., 24.], 322 [25., 26., 27., 28., 29.], 323 [30., 31., 32., 33., 34.], 324 [35., 36., 37., 38., 39.]], 325 326 [[20., 21., 22., 23., 24.], 327 [25., 26., 27., 28., 29.], 328 [30., 31., 32., 33., 34.], 329 [35., 36., 37., 38., 39.]], 330 331 [[20., 21., 22., 23., 24.], 332 [25., 26., 27., 28., 29.], 333 [30., 31., 32., 33., 34.], 334 [35., 36., 37., 38., 39.]], 335 336 [[20., 21., 22., 23., 24.], 337 [25., 26., 27., 28., 29.], 338 [30., 31., 32., 33., 34.], 339 [35., 36., 37., 38., 39.]]]], 340 341 [[[[20., 21., 22., 23., 24.], 342 [25., 26., 27., 28., 29.], 343 [30., 31., 32., 33., 34.], 344 [35., 36., 37., 38., 39.]], 345 346 [[20., 21., 22., 23., 24.], 347 [25., 26., 27., 28., 29.], 348 [30., 31., 32., 33., 34.], 349 [35., 36., 37., 38., 39.]], 350 351 [[20., 21., 22., 23., 24.], 352 [25., 26., 27., 28., 29.], 353 [30., 31., 32., 33., 34.], 354 [35., 36., 37., 38., 39.]], 355 356 [[20., 21., 22., 23., 24.], 357 [25., 26., 27., 28., 29.], 358 [30., 31., 32., 33., 34.], 359 [35., 36., 37., 38., 39.]], 360 361 [[20., 21., 22., 23., 24.], 362 [25., 26., 27., 28., 29.], 363 [30., 31., 32., 33., 34.], 364 [35., 36., 37., 38., 39.]]], 365 366 [[[20., 21., 22., 23., 24.], 367 [25., 26., 27., 28., 29.], 368 [30., 31., 32., 33., 34.], 369 [35., 36., 37., 38., 39.]], 370 371 [[20., 21., 22., 23., 24.], 372 [25., 26., 27., 28., 29.], 373 [30., 31., 32., 33., 34.], 374 [35., 36., 37., 38., 39.]], 375 376 [[20., 21., 22., 23., 24.], 377 [25., 26., 27., 28., 29.], 378 [30., 31., 32., 33., 34.], 379 [35., 36., 37., 38., 39.]], 380 381 [[20., 21., 22., 23., 24.], 382 [25., 26., 27., 28., 29.], 383 [30., 31., 32., 33., 34.], 384 [35., 36., 37., 38., 39.]], 385 386 [[20., 21., 22., 23., 24.], 387 [25., 26., 27., 28., 29.], 388 [30., 31., 32., 33., 34.], 389 [35., 36., 37., 38., 39.]]], 390 391 [[[20., 21., 22., 23., 24.], 392 [25., 26., 27., 28., 29.], 393 [30., 31., 32., 33., 34.], 394 [35., 36., 37., 38., 39.]], 395 396 [[20., 21., 22., 23., 24.], 397 [25., 26., 27., 28., 29.], 398 [30., 31., 32., 33., 34.], 399 [35., 36., 37., 38., 39.]], 400 401 [[20., 21., 22., 23., 24.], 402 [25., 26., 27., 28., 29.], 403 [30., 31., 32., 33., 34.], 404 [35., 36., 37., 38., 39.]], 405 406 [[20., 21., 22., 23., 24.], 407 [25., 26., 27., 28., 29.], 408 [30., 31., 32., 33., 34.], 409 [35., 36., 37., 38., 39.]], 410 411 [[20., 21., 22., 23., 24.], 412 [25., 26., 27., 28., 29.], 413 [30., 31., 32., 33., 34.], 414 [35., 36., 37., 38., 39.]]], 415 416 [[[20., 21., 22., 23., 24.], 417 [25., 26., 27., 28., 29.], 418 [30., 31., 32., 33., 34.], 419 [35., 36., 37., 38., 39.]], 420 421 [[20., 21., 22., 23., 24.], 422 [25., 26., 27., 28., 29.], 423 [30., 31., 32., 33., 34.], 424 [35., 36., 37., 38., 39.]], 425 426 [[20., 21., 22., 23., 24.], 427 [25., 26., 27., 28., 29.], 428 [30., 31., 32., 33., 34.], 429 [35., 36., 37., 38., 39.]], 430 431 [[20., 21., 22., 23., 24.], 432 [25., 26., 27., 28., 29.], 433 [30., 31., 32., 33., 34.], 434 [35., 36., 37., 38., 39.]], 435 436 [[20., 21., 22., 23., 24.], 437 [25., 26., 27., 28., 29.], 438 [30., 31., 32., 33., 34.], 439 [35., 36., 37., 38., 39.]]]]]], 440 441 [[[[[[80., 81., 82., 83., 84.], 442 [85., 86., 87., 88., 89.], 443 [90., 91., 92., 93., 94.], 444 [95., 96., 97., 98., 99.]], 445 446 [[80., 81., 82., 83., 84.], 447 [85., 86., 87., 88., 89.], 448 [90., 91., 92., 93., 94.], 449 [95., 96., 97., 98., 99.]], 450 451 [[80., 81., 82., 83., 84.], 452 [85., 86., 87., 88., 89.], 453 [90., 91., 92., 93., 94.], 454 [95., 96., 97., 98., 99.]], 455 456 [[80., 81., 82., 83., 84.], 457 [85., 86., 87., 88., 89.], 458 [90., 91., 92., 93., 94.], 459 [95., 96., 97., 98., 99.]], 460 461 [[80., 81., 82., 83., 84.], 462 [85., 86., 87., 88., 89.], 463 [90., 91., 92., 93., 94.], 464 [95., 96., 97., 98., 99.]]], 465 466 [[[80., 81., 82., 83., 84.], 467 [85., 86., 87., 88., 89.], 468 [90., 91., 92., 93., 94.], 469 [95., 96., 97., 98., 99.]], 470 471 [[80., 81., 82., 83., 84.], 472 [85., 86., 87., 88., 89.], 473 [90., 91., 92., 93., 94.], 474 [95., 96., 97., 98., 99.]], 475 476 [[80., 81., 82., 83., 84.], 477 [85., 86., 87., 88., 89.], 478 [90., 91., 92., 93., 94.], 479 [95., 96., 97., 98., 99.]], 480 481 [[80., 81., 82., 83., 84.], 482 [85., 86., 87., 88., 89.], 483 [90., 91., 92., 93., 94.], 484 [95., 96., 97., 98., 99.]], 485 486 [[80., 81., 82., 83., 84.], 487 [85., 86., 87., 88., 89.], 488 [90., 91., 92., 93., 94.], 489 [95., 96., 97., 98., 99.]]], 490 491 [[[80., 81., 82., 83., 84.], 492 [85., 86., 87., 88., 89.], 493 [90., 91., 92., 93., 94.], 494 [95., 96., 97., 98., 99.]], 495 496 [[80., 81., 82., 83., 84.], 497 [85., 86., 87., 88., 89.], 498 [90., 91., 92., 93., 94.], 499 [95., 96., 97., 98., 99.]], 500 501 [[80., 81., 82., 83., 84.], 502 [85., 86., 87., 88., 89.], 503 [90., 91., 92., 93., 94.], 504 [95., 96., 97., 98., 99.]], 505 506 [[80., 81., 82., 83., 84.], 507 [85., 86., 87., 88., 89.], 508 [90., 91., 92., 93., 94.], 509 [95., 96., 97., 98., 99.]], 510 511 [[80., 81., 82., 83., 84.], 512 [85., 86., 87., 88., 89.], 513 [90., 91., 92., 93., 94.], 514 [95., 96., 97., 98., 99.]]], 515 516 [[[80., 81., 82., 83., 84.], 517 [85., 86., 87., 88., 89.], 518 [90., 91., 92., 93., 94.], 519 [95., 96., 97., 98., 99.]], 520 521 [[80., 81., 82., 83., 84.], 522 [85., 86., 87., 88., 89.], 523 [90., 91., 92., 93., 94.], 524 [95., 96., 97., 98., 99.]], 525 526 [[80., 81., 82., 83., 84.], 527 [85., 86., 87., 88., 89.], 528 [90., 91., 92., 93., 94.], 529 [95., 96., 97., 98., 99.]], 530 531 [[80., 81., 82., 83., 84.], 532 [85., 86., 87., 88., 89.], 533 [90., 91., 92., 93., 94.], 534 [95., 96., 97., 98., 99.]], 535 536 [[80., 81., 82., 83., 84.], 537 [85., 86., 87., 88., 89.], 538 [90., 91., 92., 93., 94.], 539 [95., 96., 97., 98., 99.]]]], 540 541 [[[[80., 81., 82., 83., 84.], 542 [85., 86., 87., 88., 89.], 543 [90., 91., 92., 93., 94.], 544 [95., 96., 97., 98., 99.]], 545 546 [[80., 81., 82., 83., 84.], 547 [85., 86., 87., 88., 89.], 548 [90., 91., 92., 93., 94.], 549 [95., 96., 97., 98., 99.]], 550 551 [[80., 81., 82., 83., 84.], 552 [85., 86., 87., 88., 89.], 553 [90., 91., 92., 93., 94.], 554 [95., 96., 97., 98., 99.]], 555 556 [[80., 81., 82., 83., 84.], 557 [85., 86., 87., 88., 89.], 558 [90., 91., 92., 93., 94.], 559 [95., 96., 97., 98., 99.]], 560 561 [[80., 81., 82., 83., 84.], 562 [85., 86., 87., 88., 89.], 563 [90., 91., 92., 93., 94.], 564 [95., 96., 97., 98., 99.]]], 565 566 [[[80., 81., 82., 83., 84.], 567 [85., 86., 87., 88., 89.], 568 [90., 91., 92., 93., 94.], 569 [95., 96., 97., 98., 99.]], 570 571 [[80., 81., 82., 83., 84.], 572 [85., 86., 87., 88., 89.], 573 [90., 91., 92., 93., 94.], 574 [95., 96., 97., 98., 99.]], 575 576 [[80., 81., 82., 83., 84.], 577 [85., 86., 87., 88., 89.], 578 [90., 91., 92., 93., 94.], 579 [95., 96., 97., 98., 99.]], 580 581 [[80., 81., 82., 83., 84.], 582 [85., 86., 87., 88., 89.], 583 [90., 91., 92., 93., 94.], 584 [95., 96., 97., 98., 99.]], 585 586 [[80., 81., 82., 83., 84.], 587 [85., 86., 87., 88., 89.], 588 [90., 91., 92., 93., 94.], 589 [95., 96., 97., 98., 99.]]], 590 591 [[[80., 81., 82., 83., 84.], 592 [85., 86., 87., 88., 89.], 593 [90., 91., 92., 93., 94.], 594 [95., 96., 97., 98., 99.]], 595 596 [[80., 81., 82., 83., 84.], 597 [85., 86., 87., 88., 89.], 598 [90., 91., 92., 93., 94.], 599 [95., 96., 97., 98., 99.]], 600 601 [[80., 81., 82., 83., 84.], 602 [85., 86., 87., 88., 89.], 603 [90., 91., 92., 93., 94.], 604 [95., 96., 97., 98., 99.]], 605 606 [[80., 81., 82., 83., 84.], 607 [85., 86., 87., 88., 89.], 608 [90., 91., 92., 93., 94.], 609 [95., 96., 97., 98., 99.]], 610 611 [[80., 81., 82., 83., 84.], 612 [85., 86., 87., 88., 89.], 613 [90., 91., 92., 93., 94.], 614 [95., 96., 97., 98., 99.]]], 615 616 [[[80., 81., 82., 83., 84.], 617 [85., 86., 87., 88., 89.], 618 [90., 91., 92., 93., 94.], 619 [95., 96., 97., 98., 99.]], 620 621 [[80., 81., 82., 83., 84.], 622 [85., 86., 87., 88., 89.], 623 [90., 91., 92., 93., 94.], 624 [95., 96., 97., 98., 99.]], 625 626 [[80., 81., 82., 83., 84.], 627 [85., 86., 87., 88., 89.], 628 [90., 91., 92., 93., 94.], 629 [95., 96., 97., 98., 99.]], 630 631 [[80., 81., 82., 83., 84.], 632 [85., 86., 87., 88., 89.], 633 [90., 91., 92., 93., 94.], 634 [95., 96., 97., 98., 99.]], 635 636 [[80., 81., 82., 83., 84.], 637 [85., 86., 87., 88., 89.], 638 [90., 91., 92., 93., 94.], 639 [95., 96., 97., 98., 99.]]]]], 640 641 [[[[[80., 81., 82., 83., 84.], 642 [85., 86., 87., 88., 89.], 643 [90., 91., 92., 93., 94.], 644 [95., 96., 97., 98., 99.]], 645 646 [[80., 81., 82., 83., 84.], 647 [85., 86., 87., 88., 89.], 648 [90., 91., 92., 93., 94.], 649 [95., 96., 97., 98., 99.]], 650 651 [[80., 81., 82., 83., 84.], 652 [85., 86., 87., 88., 89.], 653 [90., 91., 92., 93., 94.], 654 [95., 96., 97., 98., 99.]], 655 656 [[80., 81., 82., 83., 84.], 657 [85., 86., 87., 88., 89.], 658 [90., 91., 92., 93., 94.], 659 [95., 96., 97., 98., 99.]], 660 661 [[80., 81., 82., 83., 84.], 662 [85., 86., 87., 88., 89.], 663 [90., 91., 92., 93., 94.], 664 [95., 96., 97., 98., 99.]]], 665 666 [[[80., 81., 82., 83., 84.], 667 [85., 86., 87., 88., 89.], 668 [90., 91., 92., 93., 94.], 669 [95., 96., 97., 98., 99.]], 670 671 [[80., 81., 82., 83., 84.], 672 [85., 86., 87., 88., 89.], 673 [90., 91., 92., 93., 94.], 674 [95., 96., 97., 98., 99.]], 675 676 [[80., 81., 82., 83., 84.], 677 [85., 86., 87., 88., 89.], 678 [90., 91., 92., 93., 94.], 679 [95., 96., 97., 98., 99.]], 680 681 [[80., 81., 82., 83., 84.], 682 [85., 86., 87., 88., 89.], 683 [90., 91., 92., 93., 94.], 684 [95., 96., 97., 98., 99.]], 685 686 [[80., 81., 82., 83., 84.], 687 [85., 86., 87., 88., 89.], 688 [90., 91., 92., 93., 94.], 689 [95., 96., 97., 98., 99.]]], 690 691 [[[80., 81., 82., 83., 84.], 692 [85., 86., 87., 88., 89.], 693 [90., 91., 92., 93., 94.], 694 [95., 96., 97., 98., 99.]], 695 696 [[80., 81., 82., 83., 84.], 697 [85., 86., 87., 88., 89.], 698 [90., 91., 92., 93., 94.], 699 [95., 96., 97., 98., 99.]], 700 701 [[80., 81., 82., 83., 84.], 702 [85., 86., 87., 88., 89.], 703 [90., 91., 92., 93., 94.], 704 [95., 96., 97., 98., 99.]], 705 706 [[80., 81., 82., 83., 84.], 707 [85., 86., 87., 88., 89.], 708 [90., 91., 92., 93., 94.], 709 [95., 96., 97., 98., 99.]], 710 711 [[80., 81., 82., 83., 84.], 712 [85., 86., 87., 88., 89.], 713 [90., 91., 92., 93., 94.], 714 [95., 96., 97., 98., 99.]]], 715 716 [[[80., 81., 82., 83., 84.], 717 [85., 86., 87., 88., 89.], 718 [90., 91., 92., 93., 94.], 719 [95., 96., 97., 98., 99.]], 720 721 [[80., 81., 82., 83., 84.], 722 [85., 86., 87., 88., 89.], 723 [90., 91., 92., 93., 94.], 724 [95., 96., 97., 98., 99.]], 725 726 [[80., 81., 82., 83., 84.], 727 [85., 86., 87., 88., 89.], 728 [90., 91., 92., 93., 94.], 729 [95., 96., 97., 98., 99.]], 730 731 [[80., 81., 82., 83., 84.], 732 [85., 86., 87., 88., 89.], 733 [90., 91., 92., 93., 94.], 734 [95., 96., 97., 98., 99.]], 735 736 [[80., 81., 82., 83., 84.], 737 [85., 86., 87., 88., 89.], 738 [90., 91., 92., 93., 94.], 739 [95., 96., 97., 98., 99.]]]], 740 741 [[[[80., 81., 82., 83., 84.], 742 [85., 86., 87., 88., 89.], 743 [90., 91., 92., 93., 94.], 744 [95., 96., 97., 98., 99.]], 745 746 [[80., 81., 82., 83., 84.], 747 [85., 86., 87., 88., 89.], 748 [90., 91., 92., 93., 94.], 749 [95., 96., 97., 98., 99.]], 750 751 [[80., 81., 82., 83., 84.], 752 [85., 86., 87., 88., 89.], 753 [90., 91., 92., 93., 94.], 754 [95., 96., 97., 98., 99.]], 755 756 [[80., 81., 82., 83., 84.], 757 [85., 86., 87., 88., 89.], 758 [90., 91., 92., 93., 94.], 759 [95., 96., 97., 98., 99.]], 760 761 [[80., 81., 82., 83., 84.], 762 [85., 86., 87., 88., 89.], 763 [90., 91., 92., 93., 94.], 764 [95., 96., 97., 98., 99.]]], 765 766 [[[80., 81., 82., 83., 84.], 767 [85., 86., 87., 88., 89.], 768 [90., 91., 92., 93., 94.], 769 [95., 96., 97., 98., 99.]], 770 771 [[80., 81., 82., 83., 84.], 772 [85., 86., 87., 88., 89.], 773 [90., 91., 92., 93., 94.], 774 [95., 96., 97., 98., 99.]], 775 776 [[80., 81., 82., 83., 84.], 777 [85., 86., 87., 88., 89.], 778 [90., 91., 92., 93., 94.], 779 [95., 96., 97., 98., 99.]], 780 781 [[80., 81., 82., 83., 84.], 782 [85., 86., 87., 88., 89.], 783 [90., 91., 92., 93., 94.], 784 [95., 96., 97., 98., 99.]], 785 786 [[80., 81., 82., 83., 84.], 787 [85., 86., 87., 88., 89.], 788 [90., 91., 92., 93., 94.], 789 [95., 96., 97., 98., 99.]]], 790 791 [[[80., 81., 82., 83., 84.], 792 [85., 86., 87., 88., 89.], 793 [90., 91., 92., 93., 94.], 794 [95., 96., 97., 98., 99.]], 795 796 [[80., 81., 82., 83., 84.], 797 [85., 86., 87., 88., 89.], 798 [90., 91., 92., 93., 94.], 799 [95., 96., 97., 98., 99.]], 800 801 [[80., 81., 82., 83., 84.], 802 [85., 86., 87., 88., 89.], 803 [90., 91., 92., 93., 94.], 804 [95., 96., 97., 98., 99.]], 805 806 [[80., 81., 82., 83., 84.], 807 [85., 86., 87., 88., 89.], 808 [90., 91., 92., 93., 94.], 809 [95., 96., 97., 98., 99.]], 810 811 [[80., 81., 82., 83., 84.], 812 [85., 86., 87., 88., 89.], 813 [90., 91., 92., 93., 94.], 814 [95., 96., 97., 98., 99.]]], 815 816 [[[80., 81., 82., 83., 84.], 817 [85., 86., 87., 88., 89.], 818 [90., 91., 92., 93., 94.], 819 [95., 96., 97., 98., 99.]], 820 821 [[80., 81., 82., 83., 84.], 822 [85., 86., 87., 88., 89.], 823 [90., 91., 92., 93., 94.], 824 [95., 96., 97., 98., 99.]], 825 826 [[80., 81., 82., 83., 84.], 827 [85., 86., 87., 88., 89.], 828 [90., 91., 92., 93., 94.], 829 [95., 96., 97., 98., 99.]], 830 831 [[80., 81., 82., 83., 84.], 832 [85., 86., 87., 88., 89.], 833 [90., 91., 92., 93., 94.], 834 [95., 96., 97., 98., 99.]], 835 836 [[80., 81., 82., 83., 84.], 837 [85., 86., 87., 88., 89.], 838 [90., 91., 92., 93., 94.], 839 [95., 96., 97., 98., 99.]]]]]]]) 840 841 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 842 gather = GatherNet() 843 output = gather(x, indices) 844 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 845 diff = output.asnumpy() - expect 846 assert np.all(diff < error) 847 assert np.all(-diff < error) 848 849 850class GatherNet1(nn.Cell): 851 def __init__(self): 852 super(GatherNet1, self).__init__() 853 self.gather = P.Gather() 854 855 def construct(self, x, indices): 856 return self.gather(x, indices, -1) 857 858 859@pytest.mark.level0 860@pytest.mark.platform_x86_gpu_training 861@pytest.mark.env_onecard 862def test_gather1(): 863 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) 864 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 865 expect = np.array([[[[1., 3., 4.], 866 [6., 8., 9.], 867 [11., 13., 14.], 868 [16., 18., 19.]], 869 870 [[21., 23., 24.], 871 [26., 28., 29.], 872 [31., 33., 34.], 873 [36., 38., 39.]], 874 875 [[41., 43., 44.], 876 [46., 48., 49.], 877 [51., 53., 54.], 878 [56., 58., 59.]]], 879 880 [[[61., 63., 64.], 881 [66., 68., 69.], 882 [71., 73., 74.], 883 [76., 78., 79.]], 884 885 [[81., 83., 84.], 886 [86., 88., 89.], 887 [91., 93., 94.], 888 [96., 98., 99.]], 889 890 [[101., 103., 104.], 891 [106., 108., 109.], 892 [111., 113., 114.], 893 [116., 118., 119.]]]]) 894 895 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 896 gather = GatherNet1() 897 output = gather(x, indices) 898 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 899 diff = output.asnumpy() - expect 900 assert np.all(diff < error) 901 assert np.all(-diff < error) 902 903 904class GatherNet2(nn.Cell): 905 def __init__(self): 906 super(GatherNet2, self).__init__() 907 self.gather = P.Gather() 908 909 def construct(self, x, indices): 910 return self.gather(x, indices, 0) 911 912 913@pytest.mark.level0 914@pytest.mark.platform_x86_gpu_training 915@pytest.mark.env_onecard 916def test_gather2(): 917 x = Tensor(np.array([[4., 5., 4., 1., 5.], 918 [4., 9., 5., 6., 4.], 919 [9., 8., 4., 3., 6.], 920 [0., 4., 2., 2., 8.], 921 [1., 8., 6., 2., 8.], 922 [8., 1., 9., 7., 3.], 923 [7., 9., 2., 5., 7.], 924 [9., 8., 6., 8., 5.], 925 [3., 7., 2., 7., 4.], 926 [4., 2., 8., 2., 9.]] 927 ).astype(np.float32)) 928 929 indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) 930 expect = np.array([[[0., 0., 0., 0., 0.], 931 [4., 9., 5., 6., 4.], 932 [0., 0., 0., 0., 0.]]]) 933 934 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 935 gather = GatherNet2() 936 output = gather(x, indices) 937 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 938 diff = output.asnumpy() - expect 939 assert np.all(diff < error) 940 assert np.all(-diff < error) 941 942 943# Dynamic Shape testing ahead 944class GatherNetDynamic(nn.Cell): 945 def __init__(self, axis=0, dyn_a=True, dyn_b=True): 946 super(GatherNetDynamic, self).__init__() 947 self.gather = P.Gather() 948 self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() 949 self.to_dyn_1 = dyn_a 950 self.to_dyn_2 = dyn_b 951 self.axis = axis 952 953 def construct(self, x, indices): 954 # testing selective inputs being dynamic 955 if self.to_dyn_1: 956 x = self.gpu_convert_to_dynamic_shape(x) 957 if self.to_dyn_2: 958 indices = self.gpu_convert_to_dynamic_shape(indices) 959 return self.gather(x, indices, self.axis) 960 961 962@pytest.mark.level0 963@pytest.mark.platform_x86_gpu_training 964@pytest.mark.env_onecard 965def test_gatherV2_dyn_ab(): 966 """ 967 Tests for Dynamic shape with both inputs dynamic 968 """ 969 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 970 gather = GatherNetDynamic() 971 x = Tensor(np.array([[4., 5., 4., 1., 5.], 972 [4., 9., 5., 6., 4.], 973 [9., 8., 4., 3., 6.], 974 [0., 4., 2., 2., 8.], 975 [1., 8., 6., 2., 8.], 976 [8., 1., 9., 7., 3.], 977 [7., 9., 2., 5., 7.], 978 [9., 8., 6., 8., 5.], 979 [3., 7., 2., 7., 4.], 980 [4., 2., 8., 2., 9.]] 981 ).astype(np.float32)) 982 indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) 983 expect = np.array([[[0., 0., 0., 0., 0.], 984 [4., 9., 5., 6., 4.], 985 [0., 0., 0., 0., 0.]]]) 986 output = gather(x, indices) 987 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 988 diff = output.asnumpy() - expect 989 assert np.all(diff < error) 990 assert np.all(-diff < error) 991 992 993@pytest.mark.level0 994@pytest.mark.platform_x86_gpu_training 995@pytest.mark.env_onecard 996def test_gatherV2_dyn_a(): 997 """ 998 Tests for Dynamic shape with only first input dynamic 999 """ 1000 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1001 gather = GatherNetDynamic(-1, True, False) 1002 # test 1 1003 x = Tensor(np.array([[4., 5., 4., 1., 5.], 1004 [4., 9., 5., 6., 4.], 1005 [9., 8., 4., 3., 6.], 1006 [0., 4., 2., 2., 8.], 1007 [1., 8., 6., 2., 8.], 1008 [8., 1., 9., 7., 3.], 1009 [7., 9., 2., 5., 7.], 1010 [9., 8., 6., 8., 5.], 1011 [3., 7., 2., 7., 4.], 1012 [4., 2., 8., 2., 9.]] 1013 ).astype(np.float32)) 1014 indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) 1015 expect = np.array([[[0., 5., 0.]], 1016 [[0., 9., 0.]], 1017 [[0., 8., 0.]], 1018 [[0., 4., 0.]], 1019 [[0., 8., 0.]], 1020 [[0., 1., 0.]], 1021 [[0., 9., 0.]], 1022 [[0., 8., 0.]], 1023 [[0., 7., 0.]], 1024 [[0., 2., 0.]]]).astype(np.float32) 1025 output = gather(x, indices) 1026 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1027 diff = output.asnumpy() - expect 1028 assert np.all(diff < error) 1029 assert np.all(-diff < error) 1030 # test 2 1031 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) 1032 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1033 expect = np.array([[[[1., 3., 4.], 1034 [6., 8., 9.], 1035 [11., 13., 14.], 1036 [16., 18., 19.]], 1037 1038 [[21., 23., 24.], 1039 [26., 28., 29.], 1040 [31., 33., 34.], 1041 [36., 38., 39.]], 1042 1043 [[41., 43., 44.], 1044 [46., 48., 49.], 1045 [51., 53., 54.], 1046 [56., 58., 59.]]], 1047 1048 [[[61., 63., 64.], 1049 [66., 68., 69.], 1050 [71., 73., 74.], 1051 [76., 78., 79.]], 1052 1053 [[81., 83., 84.], 1054 [86., 88., 89.], 1055 [91., 93., 94.], 1056 [96., 98., 99.]], 1057 1058 [[101., 103., 104.], 1059 [106., 108., 109.], 1060 [111., 113., 114.], 1061 [116., 118., 119.]]]]) 1062 output = gather(x, indices) 1063 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1064 diff = output.asnumpy() - expect 1065 assert np.all(diff < error) 1066 assert np.all(-diff < error) 1067 1068 1069@pytest.mark.level0 1070@pytest.mark.platform_x86_gpu_training 1071@pytest.mark.env_onecard 1072def test_gatherV2_dyn_b(): 1073 """ 1074 Tests for Dynamic shape with only second input dynamic 1075 """ 1076 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1077 gather = GatherNetDynamic(-1, False, True) 1078 # test 1 1079 x = Tensor(np.array([[4., 5., 4., 1., 5.], 1080 [4., 9., 5., 6., 4.], 1081 [9., 8., 4., 3., 6.], 1082 [0., 4., 2., 2., 8.], 1083 [1., 8., 6., 2., 8.], 1084 [8., 1., 9., 7., 3.], 1085 [7., 9., 2., 5., 7.], 1086 [9., 8., 6., 8., 5.], 1087 [3., 7., 2., 7., 4.], 1088 [4., 2., 8., 2., 9.]] 1089 ).astype(np.float32)) 1090 indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) 1091 expect = np.array([[[0., 5., 0.]], 1092 [[0., 9., 0.]], 1093 [[0., 8., 0.]], 1094 [[0., 4., 0.]], 1095 [[0., 8., 0.]], 1096 [[0., 1., 0.]], 1097 [[0., 9., 0.]], 1098 [[0., 8., 0.]], 1099 [[0., 7., 0.]], 1100 [[0., 2., 0.]]]).astype(np.float32) 1101 output = gather(x, indices) 1102 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1103 diff = output.asnumpy() - expect 1104 assert np.all(diff < error) 1105 assert np.all(-diff < error) 1106 # test 2 1107 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) 1108 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1109 expect = np.array([[[[1., 3., 4.], 1110 [6., 8., 9.], 1111 [11., 13., 14.], 1112 [16., 18., 19.]], 1113 [[21., 23., 24.], 1114 [26., 28., 29.], 1115 [31., 33., 34.], 1116 [36., 38., 39.]], 1117 [[41., 43., 44.], 1118 [46., 48., 49.], 1119 [51., 53., 54.], 1120 [56., 58., 59.]]], 1121 [[[61., 63., 64.], 1122 [66., 68., 69.], 1123 [71., 73., 74.], 1124 [76., 78., 79.]], 1125 [[81., 83., 84.], 1126 [86., 88., 89.], 1127 [91., 93., 94.], 1128 [96., 98., 99.]], 1129 [[101., 103., 104.], 1130 [106., 108., 109.], 1131 [111., 113., 114.], 1132 [116., 118., 119.]]]]) 1133 output = gather(x, indices) 1134 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1135 diff = output.asnumpy() - expect 1136 assert np.all(diff < error) 1137 assert np.all(-diff < error) 1138 1139 1140@pytest.mark.level0 1141@pytest.mark.platform_x86_gpu_training 1142@pytest.mark.env_onecard 1143def test_gather1_float64(): 1144 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)) 1145 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1146 expect = np.array([[[[1., 3., 4.], 1147 [6., 8., 9.], 1148 [11., 13., 14.], 1149 [16., 18., 19.]], 1150 1151 [[21., 23., 24.], 1152 [26., 28., 29.], 1153 [31., 33., 34.], 1154 [36., 38., 39.]], 1155 1156 [[41., 43., 44.], 1157 [46., 48., 49.], 1158 [51., 53., 54.], 1159 [56., 58., 59.]]], 1160 1161 [[[61., 63., 64.], 1162 [66., 68., 69.], 1163 [71., 73., 74.], 1164 [76., 78., 79.]], 1165 1166 [[81., 83., 84.], 1167 [86., 88., 89.], 1168 [91., 93., 94.], 1169 [96., 98., 99.]], 1170 1171 [[101., 103., 104.], 1172 [106., 108., 109.], 1173 [111., 113., 114.], 1174 [116., 118., 119.]]]]).astype(np.float64) 1175 1176 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1177 gather = GatherNet1() 1178 output = gather(x, indices) 1179 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1180 diff = output.asnumpy() - expect 1181 assert np.all(diff < error) 1182 assert np.all(-diff < error) 1183 1184 1185@pytest.mark.level0 1186@pytest.mark.platform_x86_gpu_training 1187@pytest.mark.env_onecard 1188def test_gather1_int32(): 1189 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape(2, 3, 4, 5)) 1190 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1191 expect = np.array([[[[1., 3., 4.], 1192 [6., 8., 9.], 1193 [11., 13., 14.], 1194 [16., 18., 19.]], 1195 1196 [[21., 23., 24.], 1197 [26., 28., 29.], 1198 [31., 33., 34.], 1199 [36., 38., 39.]], 1200 1201 [[41., 43., 44.], 1202 [46., 48., 49.], 1203 [51., 53., 54.], 1204 [56., 58., 59.]]], 1205 1206 [[[61., 63., 64.], 1207 [66., 68., 69.], 1208 [71., 73., 74.], 1209 [76., 78., 79.]], 1210 1211 [[81., 83., 84.], 1212 [86., 88., 89.], 1213 [91., 93., 94.], 1214 [96., 98., 99.]], 1215 1216 [[101., 103., 104.], 1217 [106., 108., 109.], 1218 [111., 113., 114.], 1219 [116., 118., 119.]]]]).astype(np.int32) 1220 1221 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1222 gather = GatherNet1() 1223 output = gather(x, indices) 1224 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1225 diff = output.asnumpy() - expect 1226 assert np.all(diff < error) 1227 assert np.all(-diff < error) 1228 1229 1230@pytest.mark.level1 1231@pytest.mark.platform_x86_gpu_training 1232@pytest.mark.env_onecard 1233def test_gather1_int16(): 1234 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int16).reshape(2, 3, 4, 5)) 1235 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1236 expect = np.array([[[[1., 3., 4.], 1237 [6., 8., 9.], 1238 [11., 13., 14.], 1239 [16., 18., 19.]], 1240 1241 [[21., 23., 24.], 1242 [26., 28., 29.], 1243 [31., 33., 34.], 1244 [36., 38., 39.]], 1245 1246 [[41., 43., 44.], 1247 [46., 48., 49.], 1248 [51., 53., 54.], 1249 [56., 58., 59.]]], 1250 1251 [[[61., 63., 64.], 1252 [66., 68., 69.], 1253 [71., 73., 74.], 1254 [76., 78., 79.]], 1255 1256 [[81., 83., 84.], 1257 [86., 88., 89.], 1258 [91., 93., 94.], 1259 [96., 98., 99.]], 1260 1261 [[101., 103., 104.], 1262 [106., 108., 109.], 1263 [111., 113., 114.], 1264 [116., 118., 119.]]]]).astype(np.int16) 1265 1266 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1267 gather = GatherNet1() 1268 output = gather(x, indices) 1269 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1270 diff = output.asnumpy() - expect 1271 assert np.all(diff < error) 1272 assert np.all(-diff < error) 1273 1274 1275@pytest.mark.level1 1276@pytest.mark.platform_x86_gpu_training 1277@pytest.mark.env_onecard 1278def test_gather1_int8(): 1279 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int8).reshape(2, 3, 4, 5)) 1280 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1281 expect = np.array([[[[1., 3., 4.], 1282 [6., 8., 9.], 1283 [11., 13., 14.], 1284 [16., 18., 19.]], 1285 1286 [[21., 23., 24.], 1287 [26., 28., 29.], 1288 [31., 33., 34.], 1289 [36., 38., 39.]], 1290 1291 [[41., 43., 44.], 1292 [46., 48., 49.], 1293 [51., 53., 54.], 1294 [56., 58., 59.]]], 1295 1296 [[[61., 63., 64.], 1297 [66., 68., 69.], 1298 [71., 73., 74.], 1299 [76., 78., 79.]], 1300 1301 [[81., 83., 84.], 1302 [86., 88., 89.], 1303 [91., 93., 94.], 1304 [96., 98., 99.]], 1305 1306 [[101., 103., 104.], 1307 [106., 108., 109.], 1308 [111., 113., 114.], 1309 [116., 118., 119.]]]]).astype(np.int8) 1310 1311 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1312 gather = GatherNet1() 1313 output = gather(x, indices) 1314 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1315 diff = output.asnumpy() - expect 1316 assert np.all(diff < error) 1317 assert np.all(-diff < error) 1318 1319 1320@pytest.mark.level1 1321@pytest.mark.platform_x86_gpu_training 1322@pytest.mark.env_onecard 1323def test_gather1_uint8(): 1324 x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5)) 1325 indices = Tensor(np.array([1, 3, 4], dtype='i4')) 1326 expect = np.array([[[[1., 3., 4.], 1327 [6., 8., 9.], 1328 [11., 13., 14.], 1329 [16., 18., 19.]], 1330 1331 [[21., 23., 24.], 1332 [26., 28., 29.], 1333 [31., 33., 34.], 1334 [36., 38., 39.]], 1335 1336 [[41., 43., 44.], 1337 [46., 48., 49.], 1338 [51., 53., 54.], 1339 [56., 58., 59.]]], 1340 1341 [[[61., 63., 64.], 1342 [66., 68., 69.], 1343 [71., 73., 74.], 1344 [76., 78., 79.]], 1345 1346 [[81., 83., 84.], 1347 [86., 88., 89.], 1348 [91., 93., 94.], 1349 [96., 98., 99.]], 1350 1351 [[101., 103., 104.], 1352 [106., 108., 109.], 1353 [111., 113., 114.], 1354 [116., 118., 119.]]]]).astype(np.uint8) 1355 1356 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1357 gather = GatherNet1() 1358 output = gather(x, indices) 1359 error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 1360 diff = output.asnumpy() - expect 1361 assert np.all(diff < error) 1362 assert np.all(-diff < error) 1363 1364 1365@pytest.mark.level1 1366@pytest.mark.platform_x86_gpu_training 1367@pytest.mark.env_onecard 1368def test_gather1_bool(): 1369 x = Tensor(np.array([[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 1, 0]], dtype=np.bool)) 1370 indices = Tensor(np.array(([1, 2]), dtype='i4')) 1371 expect = np.array([[1, 1], [0, 0], [0, 1]]).astype(np.bool) 1372 1373 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") 1374 gather = GatherNet1() 1375 output = gather(x, indices) 1376 assert np.all(expect == output.asnumpy()) 1377