• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test nn embedding """
16import numpy as np
17import pytest
18
19from mindspore import Tensor
20from mindspore.common import dtype
21from mindspore.common.api import _cell_graph_executor
22from mindspore.nn import Embedding, MultiFieldEmbeddingLookup
23from ..ut_filter import non_graph_engine
24
25
@non_graph_engine
def test_check_embedding_1():
    """Graph-compile an Embedding(20000, 768) whose third positional flag is False.

    The third positional argument is presumably ``use_one_hot`` (MindSpore
    API); the lookup indices are an all-ones (8, 128) int32 tensor.
    """
    vocab_size, embedding_size = 20000, 768
    embedding = Embedding(vocab_size, embedding_size, False)
    indices = Tensor(np.ones((8, 128)), dtype.int32)
    _cell_graph_executor.compile(embedding, indices)
31
32
@non_graph_engine
def test_check_embedding_2():
    """Graph-compile an Embedding(20000, 768) whose third positional flag is True.

    This is the one-hot counterpart of ``test_check_embedding_1`` (assuming
    the third positional argument is ``use_one_hot``); same (8, 128) int32 input.
    """
    vocab_size, embedding_size = 20000, 768
    embedding = Embedding(vocab_size, embedding_size, True)
    indices = Tensor(np.ones((8, 128)), dtype.int32)
    _cell_graph_executor.compile(embedding, indices)
38
39
@non_graph_engine
def test_check_embedding_3():
    """Graph-compile a one-hot Embedding whose table initializer is ``"zeros"``.

    The fourth positional argument is presumably the ``embedding_table``
    initializer name; the input is the same (8, 128) int32 tensor.
    """
    vocab_size, embedding_size = 20000, 768
    embedding = Embedding(vocab_size, embedding_size, True, "zeros")
    indices = Tensor(np.ones((8, 128)), dtype.int32)
    _cell_graph_executor.compile(embedding, indices)
45
46
def compile_multi_field_embedding(shape_id, shape_value, shape_field,
                                  type_id, type_value, type_field):
    """Graph-compile a MultiFieldEmbeddingLookup(20000, 768, 3) on all-ones inputs.

    Args:
        shape_id: shape of the lookup-index tensor.
        shape_value: shape of the per-index value tensor.
        shape_field: shape of the field-id tensor.
        type_id: MindSpore dtype of the lookup-index tensor.
        type_value: MindSpore dtype of the value tensor.
        type_field: MindSpore dtype of the field-id tensor.

    Any exception raised by the compile (e.g. TypeError / ValueError from
    the cell's input validation) propagates to the caller.
    """
    lookup = MultiFieldEmbeddingLookup(20000, 768, 3)
    specs = ((shape_id, type_id),
             (shape_value, type_value),
             (shape_field, type_field))
    # Order matters: indices, values, field ids — matching the cell's inputs.
    tensors = [Tensor(np.ones(shape), ms_type) for shape, ms_type in specs]
    _cell_graph_executor.compile(lookup, *tensors)
54
55
@non_graph_engine
def test_check_multifield_embedding_right_type():
    """int64 indices, float32 values and int32 field ids of shape (8, 200) compile cleanly."""
    batch_shape = (8, 200)
    compile_multi_field_embedding(shape_id=batch_shape,
                                  shape_value=batch_shape,
                                  shape_field=batch_shape,
                                  type_id=dtype.int64,
                                  type_value=dtype.float32,
                                  type_field=dtype.int32)
60
61
@non_graph_engine
def test_check_multifield_embedding_false_type_input():
    """int16 lookup indices must be rejected with TypeError.

    Values (float32) and field ids (int32) use the dtypes accepted by the
    right-type test, so the index dtype is the only invalid input.
    """
    batch_shape = (8, 200)
    with pytest.raises(TypeError):
        compile_multi_field_embedding(shape_id=batch_shape,
                                      shape_value=batch_shape,
                                      shape_field=batch_shape,
                                      type_id=dtype.int16,
                                      type_value=dtype.float32,
                                      type_field=dtype.int32)
67
68
@non_graph_engine
def test_check_multifield_embedding_false_type_value():
    """A float16 value tensor must be rejected with TypeError.

    The indices (int64) and field ids (int32) use the dtypes accepted by the
    right-type test. Previously the indices were int16, which is itself
    rejected, so the test passed even if the value-dtype check was broken.
    """
    with pytest.raises(TypeError):
        compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                      dtype.int64, dtype.float16, dtype.int32)
74
75
@non_graph_engine
def test_check_multifield_embedding_false_type_field_id():
    """An int16 field-id tensor must be rejected with TypeError.

    The indices (int64) and values (float32) use the dtypes accepted by the
    right-type test. Previously the indices were int16 as well, which masked
    whether the field-id dtype check itself fires.
    """
    with pytest.raises(TypeError):
        compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int16)
81
82
@non_graph_engine
def test_check_multifield_embedding_false_input_shape():
    """A 1-D (8,) index tensor must be rejected with ValueError.

    All dtypes are the ones accepted by the right-type test, so the shape is
    the only invalid property. Previously int16 indices and field ids were
    also passed, which made the outcome depend on the order of the cell's
    internal shape and dtype checks.
    """
    with pytest.raises(ValueError):
        compile_multi_field_embedding((8,), (8, 200), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int32)
88
89
@non_graph_engine
def test_check_multifield_embedding_false_value_shape():
    """A 1-D (8,) value tensor must be rejected with ValueError.

    All dtypes are the ones accepted by the right-type test, so the value
    shape is the only invalid property. Previously invalid int16 dtypes were
    also passed, conflating shape and dtype validation.
    """
    with pytest.raises(ValueError):
        compile_multi_field_embedding((8, 200), (8,), (8, 200),
                                      dtype.int64, dtype.float32, dtype.int32)
95
@non_graph_engine
def test_print_embedding():
    """Printing an Embedding cell works and its repr describes the cell.

    The original smoke test only printed and could never fail on wrong output.
    It now checks that the repr names the class and includes the vocabulary
    size passed to the constructor.
    """
    net = Embedding(20000, 768, False)
    print(net)
    text = repr(net)
    assert "Embedding" in text
    assert "20000" in text
100