# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for ``mindspore.nn.Embedding`` and ``MultiFieldEmbeddingLookup``."""
import numpy as np
import pytest

from mindspore import Tensor
from mindspore.common import dtype
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Embedding, MultiFieldEmbeddingLookup
from ..ut_filter import non_graph_engine


def _compile_embedding(*embedding_args):
    """Construct ``Embedding(*embedding_args)`` and graph-compile it.

    The cell is fed a single int32 tensor of ones with shape (8, 128).
    """
    embedding = Embedding(*embedding_args)
    ids = Tensor(np.ones([8, 128]), dtype.int32)
    _cell_graph_executor.compile(embedding, ids)


@non_graph_engine
def test_check_embedding_1():
    """``Embedding(20000, 768, False)`` compiles."""
    _compile_embedding(20000, 768, False)


@non_graph_engine
def test_check_embedding_2():
    """``Embedding(20000, 768, True)`` compiles."""
    _compile_embedding(20000, 768, True)


@non_graph_engine
def test_check_embedding_3():
    """``Embedding(20000, 768, True, "zeros")`` compiles."""
    _compile_embedding(20000, 768, True, "zeros")


def compile_multi_field_embedding(shape_id, shape_value, shape_field,
                                  type_id, type_value, type_field):
    """Graph-compile ``MultiFieldEmbeddingLookup(20000, 768, 3)`` on all-ones inputs.

    Args:
        shape_id, type_id: shape and dtype of the id tensor.
        shape_value, type_value: shape and dtype of the value tensor.
        shape_field, type_field: shape and dtype of the field tensor.

    Any exception raised during compilation propagates to the caller, which
    lets tests assert on invalid shapes or dtypes.
    """
    lookup = MultiFieldEmbeddingLookup(20000, 768, 3)
    # (shape, dtype) pairs in the positional order the cell is called with.
    specs = ((shape_id, type_id),
             (shape_value, type_value),
             (shape_field, type_field))
    tensors = [Tensor(np.ones(shape), tensor_type) for shape, tensor_type in specs]
    _cell_graph_executor.compile(lookup, *tensors)
@non_graph_engine
def test_check_multifield_embedding_right_type():
    """int64 ids, float32 values and int32 field ids compile without error."""
    compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                  dtype.int64, dtype.float32, dtype.int32)


@non_graph_engine
def test_check_multifield_embedding_false_type_input():
    """An int16 id tensor is rejected with TypeError."""
    with pytest.raises(TypeError):
        compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                      dtype.int16, dtype.float32, dtype.int32)


@non_graph_engine
def test_check_multifield_embedding_false_type_value():
    """A float16 value tensor is rejected with TypeError.

    The id and field tensors use the dtypes accepted in the right_type test,
    so only the value dtype can be responsible for the error.
    """
    with pytest.raises(TypeError):
        compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                      dtype.int64, dtype.float16, dtype.int32)


@non_graph_engine
def test_check_multifield_embedding_false_type_field_id():
    """An int16 field-id tensor is rejected with TypeError.

    The id and value tensors use valid dtypes so that the failure is
    attributable to the field-id dtype alone.
    """
    with pytest.raises(TypeError):
        compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int16)


@non_graph_engine
def test_check_multifield_embedding_false_input_shape():
    """A 1-D id tensor is rejected with ValueError.

    All dtypes are valid, so the result does not depend on whether the cell
    checks shapes or dtypes first.
    """
    with pytest.raises(ValueError):
        compile_multi_field_embedding((8,), (8, 200), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int32)


@non_graph_engine
def test_check_multifield_embedding_false_value_shape():
    """A 1-D value tensor is rejected with ValueError (all dtypes valid)."""
    with pytest.raises(ValueError):
        compile_multi_field_embedding((8, 200), (8,), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int32)


@non_graph_engine
def test_check_multifield_embedding_false_field_id_shape():
    """A 1-D field-id tensor is rejected with ValueError (all dtypes valid).

    NOTE(review): mirrors the id/value shape tests; assumes the field-id
    tensor is validated the same way -- confirm against the cell.
    """
    with pytest.raises(ValueError):
        compile_multi_field_embedding((8, 200), (8, 200), (8,),
                                      dtype.int64, dtype.float32, dtype.int32)


@non_graph_engine
def test_print_embedding():
    """Smoke test: printing an ``Embedding`` cell must not raise."""
    net = Embedding(20000, 768, False)
    print(net)