# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""embedding"""
import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size, get_rank
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
from mindspore import context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from .math import Range
from ..cell import Cell

__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']


@constexpr
def _check_input_2d(input_shape, param_name, func_name):
    if len(input_shape) != 2:
        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 2d, but got {len(input_shape)}")
    return True


@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


class Embedding(Cell):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using
    indices. The input to the module is a list of indices, and the output is
    the corresponding word embeddings.

    Note:
        When 'use_one_hot' is set to True, the type of `x` must be mindspore.int32.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        use_one_hot (bool): Specifies whether to apply one-hot encoding. Default: False.
        embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        dtype (:class:`mindspore.dtype`): Data type of `x`. Default: mindspore.float32.
        padding_idx (int, None): If not None, the embedding vector at index `padding_idx` is initialized
                                 to zero. Default: None. The feature is disabled when set to None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
          the Tensor must be integers and not larger than vocab_size. Otherwise, the corresponding embedding
          vector will be zero. The data type is int32 or int64.

    Outputs:
        Tensor of shape :math:`(\text{batch_size}, \text{x_length}, \text{embedding_size})`.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` is not an int.
        TypeError: If `use_one_hot` is not a bool.
        ValueError: If `padding_idx` is an int which is not in range [0, `vocab_size`].

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Embedding(20000, 768, True)
        >>> x = Tensor(np.ones([8, 128]), mindspore.int32)
        >>> # Maps the input word IDs to word embedding.
        >>> output = net(x)
        >>> result = output.shape
        >>> print(result)
        (8, 128, 768)
    """

    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        """Initialize Embedding."""
        super(Embedding, self).__init__()
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
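        # If padding_idx is given, zero out that row of the initial table so padded ids map to a zero vector.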
        if padding_idx is not None:
            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
                self.init_tensor = self.init_tensor.init_data()
            self.init_tensor = self.init_tensor.asnumpy()
            self.init_tensor[self.padding_idx] = 0
            self.init_tensor = Tensor(self.init_tensor)
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()

    def construct(self, ids):
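        # Flatten the ids, look up the corresponding rows (via one-hot matmul or gather),
        # then reshape the result back to ids.shape + (embedding_size,).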
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + (self.embedding_size,)
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)

        if self.use_one_hot:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)

        output = self.reshape(output_for_reshape, out_shape)
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
        return s


@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
    return axis


class EmbeddingLookup(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices.

    Note:
        When 'target' is set to 'CPU', this module uses
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which looks up
        the table with 'offset = 0'.
        When 'target' is set to 'DEVICE', this module uses P.Gather(), which looks up
        the table with 'axis = 0'.
        In field slice mode, manual_shapes must be given. It is a tuple whose
        i-th element vocab[i] is the number of rows in the i-th part.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
                                       or None. Default: None.
        sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be True. Default: True.
        vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only on
            the 'DEVICE' target, and the moments parameter of the corresponding optimizer will also be set to
            the cache size. In addition, note that it consumes 'DEVICE' memory, so a reasonable value is
            suggested to avoid insufficient memory.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of the
          embedding_table, and the exceeding part will be filled with 0 in the output. Negative values are
          not supported and the result is undefined if values are negative. input_indices must be a 2d tensor
          in this interface when run in semi auto parallel/auto parallel mode.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
        TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` is less than 1.
        ValueError: If `vocab_cache_size` is less than 0.
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice',
                    'table_row_slice' or 'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> result = nn.EmbeddingLookup(4, 2)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)
    """
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True, vocab_cache_size=0):
        """Initialize EmbeddingLookup."""
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
        self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
        self.target = target
        self.sparse = sparse
        self.cache_enable = self.vocab_cache_size > 0
        self.forward_unique = False
        validator.check_string(target, ['CPU', 'DEVICE'], 'target', self.cls_name)
        if not sparse and target == 'CPU':
            raise ValueError(f"For '{self.cls_name}', 'sparse' must be True when 'target' is \"CPU\", "
                             f"but got 'sparse': {sparse} and 'target': {target}")
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.Gather()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        enable_ps = _get_ps_context("enable_ps")
        if enable_ps:
            self._process_vocab_cache(slice_mode)
        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()
        if is_auto_parallel:
            self.unique = P.Unique().shard(((1,),))
        if self.cache_enable and enable_ps:
            self._set_vocab_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
            if is_auto_parallel:
                self.unique.add_prim_attr('cache_enable', True)
        indices_shape_size = 2
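        # The branches below configure operator sharding strategies for semi_auto_parallel/auto_parallel
        # according to slice_mode; they take no effect outside auto parallel.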
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(f"For '{self.cls_name}', 'manual_shapes' should not be None "
                                 f"when 'slice_mode' is \"field_slice\", but got {manual_shapes}.")
            if not isinstance(manual_shapes, tuple):
                raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), "
                                f"but got {type(manual_shapes).__name__}!")
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            full_batch = _get_full_batch()
            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(),)))
                self.forward_unique = True
            indices_strategy = (1,)*indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,)*indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1]*(indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
                                 "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
        if self.cache_enable and not enable_ps:
            if parallel_mode != ParallelMode.STAND_ALONE:
                raise ValueError(f"For '{self.cls_name}', cache enable is not supported in this parallel mode yet.")
            self._set_cache_enable()
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)

    def _set_cache_enable(self):
        """EmbeddingLookup cache check for the non-PS environment; only 'Ascend' is supported."""
        if self.target != 'DEVICE':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when 'target' is 'DEVICE', but got 'target': {self.target}.")
        if not self.sparse:
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when 'sparse' is True, but got 'sparse': {self.sparse}.")
        if context.get_context("device_target") != 'Ascend':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when the device target is 'Ascend', but got {context.get_context('device_target')}.")

        logger.info("EmbeddingLookup cache enable takes effect.")
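        # Run Unique and the first Reshape on the host (CPU) and mark the embedding table as
        # cache-enabled with the configured cache shape.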
        self.forward_unique = True
        self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
        self.unique.add_prim_attr('cache_enable', True)
        self.embedding_table.cache_enable = self.cache_enable
        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
        self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')

    def _process_vocab_cache(self, slice_mode):
        """PS embeddingLookup cache check and process."""
        self.cache_enable = False
        if self.vocab_cache_size > 0:
            if self.target == 'CPU':
                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                               "current target is CPU, so it will be ignored.")
                return
            enable_ps = _get_ps_context("enable_ps")
            if not enable_ps:
                logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server training "
                               "mode, current mode is not parameter server training mode, so it will be ignored.")
                return
            parallel_mode = _get_parallel_mode()
            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
            if is_auto_parallel:
                rank_size = get_group_size()
                rank_id = get_rank()
                full_batch = _get_full_batch()
                if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                    raise ValueError(f"For '{self.cls_name}', the cache of parameter server parallel should only "
                                     f"be used with \"full_batch\" set to True and 'slice_mode' set to "
                                     f"\"table_row_slice\", but got full_batch: {full_batch} "
                                     f"and 'slice_mode': \"{slice_mode}\".")
                self.vocab_cache_size = self.vocab_cache_size * rank_size
                _set_rank_id(rank_id)
            self.cache_enable = True
            if _is_role_worker():
                self.vocab_size = self.vocab_cache_size
                if context.get_context("enable_sparse") != self.sparse:
                    raise ValueError(f"For '{self.cls_name}', the value of parameter 'sparse' must be the same for "
                                     f"all kernels and equal to the value of 'enable_sparse' in the context setting "
                                     f"in parameter server cache mode, but got 'sparse': {self.sparse} and "
                                     f"'enable_sparse' in context: {context.get_context('enable_sparse')}.")

    def _set_vocab_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """PS embeddingLookup cache enable set."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        _set_cache_enable(True)
        if self.sparse:
            self.forward_unique = True
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

    def construct(self, indices):
        if self.target == "CPU":
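            # On the CPU target, the EmbeddingLookup primitive looks up the table with offset 0.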
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.forward_unique:
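                # Deduplicate the flattened indices, gather each unique row once,
                # then map the gathered rows back to the original positions.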
                shp = self.shape(indices) + (self.embedding_size,)
                indices_flatten = self.reshape_first(indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, indices, 0)
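        # Optionally clip each looked-up embedding to max_norm over the trailing (embedding) axes.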
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out


class MultiFieldEmbeddingLookup(EmbeddingLookup):
    r"""
    Returns a slice of the input tensor based on the specified indices and the field ids. This operation
    supports looking up embeddings using multi hot and one hot fields simultaneously.

    Note:
        When 'target' is set to 'CPU', this module uses
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU'), which looks up
        the table with 'offset = 0'.
        When 'target' is set to 'DEVICE', this module uses P.Gather(), which looks up
        the table with 'axis = 0'.
        The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
        'MEAN'. Ensure the input_values of the padded ids are zero, so that they can be ignored. The final
        output will be zeros if the sum of absolute weights of the field is zero. This class only
        supports ['table_row_slice', 'batch_slice' and 'table_column_slice']. For the operation 'MAX' on
        Ascend devices, there is a constraint that batch_size * (seq_length + field_size) < 3500.

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        field_size (int): The field size of the final outputs.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        feature_num_list (tuple): The accompaniment array in field slice mode. This is currently unused.
            Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
                                       or None. Default: None.
        sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be True. Default: True.
        operator (str): The pooling method for the features in one field. Support 'SUM', 'MEAN' and 'MAX'.
            Default: 'SUM'.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the indices of elements of the original Tensor. input_indices must be a 2d tensor in
          this interface. Type is Int32, Int64.
        - **input_values** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the weights of elements of the input_indices. The looked-up vector will be multiplied by
          the input_values. Type is Float32.
        - **field_ids** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the field id of elements of the input_indices. Type is Int32.

    Outputs:
        Tensor, the shape of tensor is :math:`(batch\_size, field\_size, embedding\_size)`. Type is Float32.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
        TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice' or
                    'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
        ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
        >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
        >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out.shape)
        (2, 2, 2)
    """
    OPERATOR_SUM = 'SUM'
    OPERATOR_MEAN = 'MEAN'
    OPERATOR_MAX = 'MAX'

    def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
        """Initialize MultiFieldEmbeddingLookup."""
        super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                        slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = validator.check_positive_int(field_size, 'field_size', self.cls_name)
        self.operator = operator

        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.Add()
        self.inf_add = P.Add()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.Add()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        validator.check_string(operator, ['SUM', 'MAX', 'MEAN'], 'operator', self.cls_name)
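        # Select the segment reduction used to pool vectors within each field.
        # 'MEAN' reuses UnsortedSegmentSum here and is normalized by the weight count in construct.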
        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        else:
            self.merge_op = P.UnsortedSegmentSum()

        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
            self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(),),))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(),), (get_group_size(),)))
            self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
                self.count_op.shard(((get_group_size(),), (get_group_size(),)))
                self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1,), (1,)))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1,),))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1,)))
                self.count_op.shard(((1,), (1,)))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice', "
                                 "'table_column_slice'], but got \"{}\".".format(self.cls_name, str(slice_mode)))

        # Min value for fp32
        self.negative_inf_value = -3.402823466E+38

    def construct(self, input_indices, input_values, field_ids):

        _check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
        _check_input_2d(F.shape(input_values), "input_values", self.cls_name)
        _check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
        _check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
        _check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
        _check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)

        batch_size = self.shape(input_indices)[0]
        num_segments = batch_size * self.field_size
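        # Offset each sample's field ids by sample_index * field_size so that a single
        # unsorted-segment reduction over num_segments pools every (sample, field) pair at once.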
        bias = Range(0, num_segments, self.field_size)()
        bias = self.reshape(bias, (batch_size, -1))
        field_ids = self.bias_add(field_ids, bias)

        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, input_indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(input_indices) + (self.embedding_size,)
                indices_flatten = self.reshape(input_indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, input_indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(input_indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)

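        # Weight every looked-up vector by its input_value; padded positions carry weight 0.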
        weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
        embedding = self.mul(weights, out)

        if self.operator == 'MAX':
            # Fill padded positions with -inf so they do not influence the segment max.
            negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
            inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
            embedding = self.inf_add(embedding, inf_mask)
            embedding = self.reshape(embedding, (-1, self.embedding_size))
            field_ids = self.reshape(field_ids, (-1,))

        merged_vectors = self.merge_op(embedding, field_ids, num_segments)

        if self.operator == 'MAX':
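            # Zero out fields that have no valid entries, since their segment max is the -inf filler.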
            value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
            value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
            count = self.expand(value_zeros, -1)
            merged_vectors = self.max_mask_mul(merged_vectors, count)

        if self.operator == 'MEAN':
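            # Divide the per-field sum by the total absolute weight of the field;
            # div_no_nan returns 0 for fields whose weight sum is 0.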
            value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
            value_count = self.expand(value_count, -1)
            merged_vectors = self.div_no_nan(merged_vectors, value_count)

        merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
        return merged_vectors