# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')


class NetOneHot(nn.Cell):
    """Apply four ``nn.OneHot`` layers with different axis/depth settings.

    Every layer uses ``on_value=2.0`` and ``off_value=3.0``; the layers
    differ only in the axis the new one-hot dimension is inserted at and in
    the depth (6 or 4).
    """

    def __init__(self):
        super().__init__()
        self.on_value = 2.0
        self.off_value = 3.0

        self.depth_1 = 6
        self.depth_2 = 4

        # Positional arguments are (axis, depth, on_value, off_value).
        self.one_hot_1 = nn.OneHot(-1, self.depth_1, self.on_value, self.off_value)
        self.one_hot_2 = nn.OneHot(0, self.depth_1, self.on_value, self.off_value)
        self.one_hot_3 = nn.OneHot(0, self.depth_2, self.on_value, self.off_value)
        self.one_hot_4 = nn.OneHot(1, self.depth_1, self.on_value, self.off_value)

    # NOTE(review): @ms_function under PYNATIVE_MODE presumably exercises the
    # graph-compiled path of OneHot; confirm before removing it.
    @ms_function
    def construct(self, indices1, indices2, indices3, indices4):
        # Returns a 4-tuple: one_hot_k applied to indices_k for k = 1..4.
        return (self.one_hot_1(indices1), self.one_hot_2(indices2),
                self.one_hot_3(indices3), self.one_hot_4(indices4))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_one_hot():
    """Check ``nn.OneHot`` values and output shapes for several axis/depth combos.

    Index 6 lies outside ``[0, 6)``, so the expected encodings for it are
    entirely ``off_value`` (3.0).
    """
    one_hot = NetOneHot()
    indices1 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
    indices2 = Tensor(np.array([1, 2, 3]).astype(np.int32))
    indices3 = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
    indices4 = Tensor(np.array([[0, 1], [4, 5], [2, 6]]).astype(np.int32))
    output = one_hot(indices1, indices2, indices3, indices4)

    # axis=-1, depth=6: input (3, 2) -> output (3, 2, 6)
    expect_0 = np.array([
        [[2., 3., 3., 3., 3., 3.], [3., 2., 3., 3., 3., 3.]],
        [[3., 3., 3., 3., 2., 3.], [3., 3., 3., 3., 3., 2.]],
        [[3., 3., 2., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]]
    ]).astype(np.float32)
    # axis=0, depth=6: input (3,) -> output (6, 3)
    expect_1 = np.array([
        [3., 3., 3.],
        [2., 3., 3.],
        [3., 2., 3.],
        [3., 3., 2.],
        [3., 3., 3.],
        [3., 3., 3.]
    ]).astype(np.float32)
    # axis=0, depth=4: input (2, 2) -> output (4, 2, 2)
    expect_2 = np.array([
        [[2., 3.], [3., 2.]], [[3., 2.], [2., 3.]], [[3., 3.], [3., 3.]],
        [[3., 3.], [3., 3.]]
    ]).astype(np.float32)
    # axis=1, depth=6: input (3, 2) -> output (3, 6, 2)
    expect_3 = np.array([
        [[2., 3.], [3., 2.], [3., 3.], [3., 3.], [3., 3.], [3., 3.]],
        [[3., 3.], [3., 3.], [3., 3.], [3., 3.], [2., 3.], [3., 2.]],
        [[3., 3.], [3., 3.], [2., 3.], [3., 3.], [3., 3.], [3., 3.]]
    ]).astype(np.float32)

    # assert_array_equal also enforces matching shapes; the former
    # `(actual == expected).all()` would accept a wrongly-shaped output
    # whenever it happened to broadcast against the expected array.
    np.testing.assert_array_equal(output[0].asnumpy(), expect_0)
    np.testing.assert_array_equal(output[1].asnumpy(), expect_1)
    np.testing.assert_array_equal(output[2].asnumpy(), expect_2)
    np.testing.assert_array_equal(output[3].asnumpy(), expect_3)