# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class NetConv3dTranspose(nn.Cell):
    """Minimal cell wrapping a single ``P.Conv3DTranspose`` primitive.

    The primitive is built with 2 input channels, 2 output channels, a
    kernel size of 2, ``pad_mode="pad"`` with ``pad=1``, and stride,
    dilation and group all set to 1.
    """

    def __init__(self):
        super(NetConv3dTranspose, self).__init__()
        in_channel = 2
        out_channel = 2
        kernel_size = 2
        self.conv_trans = P.Conv3DTranspose(in_channel, out_channel,
                                            kernel_size,
                                            pad_mode="pad",
                                            pad=1,
                                            stride=1,
                                            dilation=1,
                                            group=1)

    def construct(self, x, w):
        """Apply the transposed 3D convolution to input ``x`` with weight ``w``."""
        return self.conv_trans(x, w)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_conv3d_transpose():
    """Check Conv3DTranspose on GPU against fixed values in both modes.

    The input is ``arange(54)`` reshaped to (1, 2, 3, 3, 3) and the weight is
    an all-ones (2, 2, 2, 2, 2) tensor, both float32. The same expected
    (1, 2, 2, 2, 2) result must be produced in PyNative mode and then in
    Graph mode.
    """
    x = Tensor(np.arange(1 * 2 * 3 * 3 * 3).reshape(1, 2, 3, 3, 3).astype(np.float32))
    w = Tensor(np.ones((2, 2, 2, 2, 2)).astype(np.float32))
    expect = np.array([[[[[320., 336.],
                          [368., 384.]],
                         [[464., 480.],
                          [512., 528.]]],
                        [[[320., 336.],
                          [368., 384.]],
                         [[464., 480.],
                          [512., 528.]]]]]).astype(np.float32)

    # Same order as before: PyNative first, then Graph. A fresh cell is
    # built after each mode switch, exactly as the duplicated code did.
    modes = (("PYNATIVE_MODE", context.PYNATIVE_MODE),
             ("GRAPH_MODE", context.GRAPH_MODE))
    for mode_name, mode in modes:
        context.set_context(mode=mode, device_target="GPU")
        conv3dtranspose = NetConv3dTranspose()
        output = conv3dtranspose(x, w)
        # Exact comparison, as in the original, but this also rejects shape
        # mismatches and reports which elements differ and in which mode.
        np.testing.assert_array_equal(
            output.asnumpy(), expect,
            err_msg="Conv3DTranspose mismatch in {}".format(mode_name))