• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_pynative_embeddinglookup """
16import pytest
17import numpy as np
18import mindspore.ops.operations as op
19from mindspore import Tensor, context
20from mindspore.nn import Cell
21
def setup_module():
    """Module-level pytest setup hook: run the tests in PyNative mode on Ascend."""
    execution_mode = context.PYNATIVE_MODE
    context.set_context(mode=execution_mode, device_target="Ascend")
24
class MetaFactory:
    """Base for op test factories; records the device target of the current run.

    Only ``device_target`` is populated here, read from the active MindSpore
    context. The distributed-run fields start as ``None``; nothing visible in
    this module assigns them (presumably subclasses or other tests do — TODO
    confirm).
    """

    def __init__(self):
        self.device_target = context.get_context('device_target')
        # Placeholders for multi-device runs, assigned left to right.
        self.rank_size = self.device_id = self.global_rank_id = None
31
class OpsFactory(MetaFactory):
    """Factory base that also selects a comparison tolerance from ``dtype``.

    Args:
        dtype: element type of the data under test. float16/32/64 map to
            tolerances 1e-3/1e-4/1e-5; any other dtype gets 0.
    """

    def __init__(self, dtype=np.float16):
        super().__init__()
        self.dtype = dtype
        # Ordered (dtype, tolerance) pairs, checked with == in sequence just
        # like an if/elif chain; no match (e.g. integer dtypes) falls back to 0.
        tolerance_table = (
            (np.float16, 1e-3),
            (np.float32, 1e-4),
            (np.float64, 1e-5),
        )
        self.loss = next(
            (tolerance for candidate, tolerance in tolerance_table if self.dtype == candidate),
            0,
        )
44
class EmbeddingLookup(Cell):
    """Cell wrapping the ``EmbeddingLookup`` primitive with a fixed offset.

    Args:
        offset: value passed as the third argument to the primitive on every call.
    """

    def __init__(self, offset):
        super().__init__()
        self.op = op.EmbeddingLookup()
        self.offset = offset

    def construct(self, params, indices):
        """Return the primitive applied to ``(params, indices, self.offset)``."""
        return self.op(params, indices, self.offset)
54
class EmbeddingLookupFactory(OpsFactory):
    """Generate random EmbeddingLookup inputs and run the MindSpore op on them.

    Args:
        params_shape: shape of the params table (standard-normal samples cast
            to ``dtype``).
        indices_shape: shape of the index array.
        offset: offset forwarded to the op.
        low: inclusive lower bound for the random indices.
        high: exclusive upper bound for the random indices.
        dtype: dtype of the params table; also forwarded to ``OpsFactory``.
        ids_type: dtype of the index array.
    """

    def __init__(self, params_shape, indices_shape, offset=0, low=0, high=2, dtype=np.float32, ids_type=np.int32):
        super().__init__(dtype=dtype)
        # Params are drawn before indices so the global RNG is consumed in
        # the same order as before.
        raw_params = np.random.randn(*params_shape)
        self.input_np = raw_params.astype(dtype)
        raw_indices = np.random.randint(low, high, size=indices_shape)
        self.indices_np = raw_indices.astype(ids_type)
        self.offset = offset
        self.output_grad_np = None  # not read by forward_mindspore_impl

    def forward_mindspore_impl(self):
        """Run the op once on the stored inputs and return the result as NumPy."""
        net = EmbeddingLookup(self.offset)
        params_tensor = Tensor(self.input_np)
        indices_tensor = Tensor(self.indices_np)
        return net(params_tensor, indices_tensor).asnumpy()
67
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_embeddinglookup_indices_outrange():
    """Out-of-range lookups should produce all-zero rows.

    Indices are drawn from [1, 3) with offset=10 against a 2-row params
    table, so every lookup is expected to miss and yield zeros of shape
    (2, 3, 4). This relies on the op's out-of-range behaviour, assumed per the
    MindSpore EmbeddingLookup docs; confirm it there.
    """
    fact = EmbeddingLookupFactory(params_shape=(2, 4), indices_shape=(2, 3), low=1, high=3, offset=10, dtype=np.int8)
    out = fact.forward_mindspore_impl()
    out_expect = np.zeros((2, 3, 4))
    # np.allclose broadcasts, so pin the shape explicitly first.
    assert out.shape == out_expect.shape
    # np.allclose only returns a bool; without `assert` the check was a no-op.
    assert np.allclose(out_expect, out)
77