# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_pynative_embeddinglookup """
import pytest
import numpy as np
import mindspore.ops.operations as op
from mindspore import Tensor, context
from mindspore.nn import Cell


def setup_module():
    """Run every test in this module in PyNative mode on an Ascend device."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")


class MetaFactory:
    """Base class for op-test factories; records the current device target."""

    def __init__(self):
        self.device_target = context.get_context('device_target')
        # Distributed-run placeholders; not read by the test in this file.
        self.rank_size = None
        self.device_id = None
        self.global_rank_id = None


class OpsFactory(MetaFactory):
    """Op-test factory that derives a comparison tolerance from the data dtype.

    Args:
        dtype: numpy dtype of the op inputs. ``self.loss`` is set to 1e-3 for
            float16, 1e-4 for float32, 1e-5 for float64, and 0 (exact match)
            for any other dtype, e.g. integer types.
    """

    def __init__(self, dtype=np.float16):
        super().__init__()
        self.dtype = dtype
        if self.dtype == np.float16:
            self.loss = 1e-3
        elif self.dtype == np.float32:
            self.loss = 1e-4
        elif self.dtype == np.float64:
            self.loss = 1e-5
        else:
            self.loss = 0


class EmbeddingLookup(Cell):
    """Cell wrapping ``ops.EmbeddingLookup`` with a fixed ``offset`` argument."""

    def __init__(self, offset):
        super().__init__()
        self.op = op.EmbeddingLookup()
        self.offset = offset

    def construct(self, params, indices):
        """Look up ``indices`` in ``params``, passing ``self.offset`` to the op.

        Per the MindSpore docs, the effective index is ``indices - offset``.
        """
        x = self.op(params, indices, self.offset)
        return x


class EmbeddingLookupFactory(OpsFactory):
    """Build random inputs for an EmbeddingLookup test and run the op on them.

    Args:
        params_shape: shape of the random lookup table (``self.input_np``).
        indices_shape: shape of the random index array (``self.indices_np``).
        offset: offset forwarded to the EmbeddingLookup op.
        low, high: indices are drawn from the half-open range [low, high).
        dtype: dtype of the lookup table; also selects ``self.loss``.
        ids_type: dtype of the index array.
    """

    def __init__(self, params_shape, indices_shape, offset=0, low=0, high=2, dtype=np.float32, ids_type=np.int32):
        super().__init__(dtype=dtype)
        self.input_np = np.random.randn(*params_shape).astype(dtype)
        self.indices_np = np.random.randint(low, high, size=indices_shape).astype(ids_type)
        self.offset = offset
        self.output_grad_np = None

    def forward_mindspore_impl(self):
        """Run the op on the stored inputs and return the result as numpy."""
        net = EmbeddingLookup(self.offset)
        out = net(Tensor(self.input_np), Tensor(self.indices_np))
        return out.asnumpy()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_embeddinglookup_indices_outrange():
    """Indices that land outside ``params`` after the offset yield all-zero rows.

    Indices are drawn from [1, 3) and offset is 10. The effective indices are
    therefore negative, i.e. outside the 2 rows of ``params``. Every looked-up
    row of the (2, 3, 4) output is expected to be zero.
    """
    fact = EmbeddingLookupFactory(params_shape=(2, 4), indices_shape=(2, 3), low=1, high=3, offset=10, dtype=np.int8)
    out = fact.forward_mindspore_impl()
    out_expect = np.zeros((2, 3, 4), dtype=fact.input_np.dtype)
    # np.allclose only returns a bool, so its result must be asserted. It also
    # broadcasts its operands, so check the shape explicitly first.
    assert out.shape == out_expect.shape
    assert np.allclose(out, out_expect, rtol=fact.loss, atol=fact.loss)