# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""embedding"""
import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size, get_rank
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
from mindspore import context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from .basic import ClipByNorm
from .math import Range
from ..cell import Cell

__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']


@constexpr
def _check_input_2d(input_shape, param_name, func_name):
    if len(input_shape) != 2:
        raise ValueError(f"For '{func_name}', the dimension of '{param_name}' should be 2d, "
                         f"but got {len(input_shape)}")
    return True


@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


class Embedding(Cell):
    r"""
    A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using
    indices. The input to the module is a list of indices, and the output is
    the corresponding word embeddings.

    Note:
        When 'use_one_hot' is set to True, the type of `x` must be mindspore.int32.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
        embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        dtype (:class:`mindspore.dtype`): Data type of `x`. Default: mindspore.float32.
        padding_idx (int, None): If not None, the embedding vector at index `padding_idx` is
            initialized to zero. Default: None, which deactivates this feature.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
          the Tensor must be integers and no larger than vocab_size; otherwise the corresponding embedding
          vector will be zero. The data type is int32 or int64.

    Outputs:
        Tensor of shape :math:`(\text{batch_size}, \text{x_length}, \text{embedding_size})`.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` is not an int.
        TypeError: If `use_one_hot` is not a bool.
        ValueError: If `padding_idx` is an int that is not in range [0, `vocab_size`].

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Embedding(20000, 768, True)
        >>> x = Tensor(np.ones([8, 128]), mindspore.int32)
        >>> # Maps the input word IDs to word embedding.
        >>> output = net(x)
        >>> result = output.shape
        >>> print(result)
        (8, 128, 768)
    """

    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
                 dtype=mstype.float32, padding_idx=None):
        """Initialize Embedding."""
        super(Embedding, self).__init__()
        self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
        self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
        validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.use_one_hot = use_one_hot
        self.dtype = dtype
        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
        self.padding_idx = padding_idx
        if padding_idx is not None:
            self.padding_idx = validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
                                                         "padding_idx", self.cls_name)
            if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None:
                self.init_tensor = self.init_tensor.init_data()
            self.init_tensor = self.init_tensor.asnumpy()
            self.init_tensor[self.padding_idx] = 0
            self.init_tensor = Tensor(self.init_tensor)
        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
        self.expand = P.ExpandDims()
        self.reshape_flat = P.Reshape()
        self.shp_flat = (-1,)
        self.gather = P.Gather()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.get_shp = P.Shape()

    def construct(self, ids):
        extended_ids = self.expand(ids, -1)
        out_shape = self.get_shp(ids) + (self.embedding_size,)
        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)

        if self.use_one_hot:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)

        output = self.reshape(output_for_reshape, out_shape)
        return output

    def extend_repr(self):
        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype,
            self.padding_idx)
        return s
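

# A minimal usage sketch of Embedding, assuming MindSpore is installed and the cell is exposed
# as nn.Embedding; kept in comments so importing this module stays side-effect free.
#
#     import numpy as np
#     import mindspore
#     from mindspore import Tensor, nn
#
#     # Row 3 of the embedding table starts as all zeros because padding_idx=3, so looking up
#     # id 3 returns a zero vector until training updates that row.
#     net = nn.Embedding(vocab_size=10, embedding_size=4, padding_idx=3)
#     ids = Tensor(np.array([[3, 1, 2]]), mindspore.int32)
#     print(net(ids).shape)  # (1, 3, 4)
#
#     # With use_one_hot=True the lookup is computed as one_hot(ids) @ embedding_table,
#     # which requires int32 ids (see the Note in the class docstring).
#     net_oh = nn.Embedding(10, 4, use_one_hot=True)
#     print(net_oh(ids).shape)  # (1, 3, 4)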


@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
    return axis


class EmbeddingLookup(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
        specifies 'offset = 0' for the lookup table.
        When 'target' is set to 'DEVICE', this module will use P.Gather() which
        specifies 'axis = 0' for the lookup table.
        In field slice mode, `manual_shapes` must be given. It is a tuple whose i-th element,
        vocab[i], is the number of rows in the i-th part of the table.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.
        sparse (bool): Use sparse mode. When 'target' is set to 'CPU', 'sparse' must be True. Default: True.
        vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only
            when 'target' is 'DEVICE', and the moment parameter of the corresponding optimizer will also be
            set to the cache size. Note that the cache consumes 'DEVICE' memory, so a reasonable value is
            suggested to avoid insufficient memory.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of the range of the
          embedding_table, and the exceeding part will be filled with 0 in the output. Negative values are
          not supported and the result is undefined if values are negative. input_indices must be a 2d
          tensor in this interface when run in semi auto parallel/auto parallel mode.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
        TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` is less than 1.
        ValueError: If `vocab_cache_size` is less than 0.
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice',
            'table_row_slice' or 'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
        >>> result = nn.EmbeddingLookup(4, 2)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)
    """
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True, vocab_cache_size=0):
        """Initialize EmbeddingLookup."""
        super(EmbeddingLookup, self).__init__()
        validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
        self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
        self.target = target
        self.sparse = sparse
        self.cache_enable = self.vocab_cache_size > 0
        self.forward_unique = False
        validator.check_string(target, ['CPU', 'DEVICE'], 'target', self.cls_name)
        if not sparse and target == 'CPU':
            raise ValueError(f"For '{self.cls_name}', 'sparse' must be True when 'target' is \"CPU\", "
                             f"but got 'sparse': {sparse} and 'target': {target}")
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.Gather()
        self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
        enable_ps = _get_ps_context("enable_ps")
        if enable_ps:
            self._process_vocab_cache(slice_mode)
        self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                         name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()
        if is_auto_parallel:
            self.unique = P.Unique().shard(((1,),))
        if self.cache_enable and enable_ps:
            self._set_vocab_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
            if is_auto_parallel:
                self.unique.add_prim_attr('cache_enable', True)
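        # How the slice modes below shard the lookup, summarized from the shard() calls in each
        # branch ("n" denotes the device group size, get_group_size()):
        #   - field_slice:        table strategy (n, 1) with uneven row splits given by
        #                         'manual_shapes' (manual_split); indices strategy (1, n).
        #   - table_row_slice:    table split by rows, strategy (n, 1); indices replicated.
        #   - table_column_slice: table split by columns, strategy (1, n); indices replicated.
        #   - batch_slice:        table replicated, strategy (1, 1); indices split along the
        #                         batch dimension, strategy (n, 1, ..., 1).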
        indices_shape_size = 2
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(f"For '{self.cls_name}', the 'manual_shapes' must not be None "
                                 f"when the 'slice_mode' is \"field_slice\", but got {manual_shapes}.")
            if not isinstance(manual_shapes, tuple):
                raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), "
                                f"but got {type(manual_shapes).__name__}!")
            for dim in manual_shapes:
                validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            full_batch = _get_full_batch()
            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(),)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
                                 "but got \"{}\".".format(self.cls_name, support_mode, slice_mode))
        if self.cache_enable and not enable_ps:
            if parallel_mode != ParallelMode.STAND_ALONE:
                raise ValueError(f"For '{self.cls_name}', the current parallel mode does not support "
                                 f"cache enable yet.")
            self._set_cache_enable()
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
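
    # A minimal usage sketch, assuming MindSpore is installed and this cell is exposed as
    # nn.EmbeddingLookup; kept in comments so the class definition itself is unchanged.
    #
    #     import numpy as np
    #     import mindspore
    #     from mindspore import Tensor, nn
    #
    #     # The default target='CPU' requires sparse=True (enforced in __init__ above).
    #     lookup = nn.EmbeddingLookup(vocab_size=4, embedding_size=2)
    #     ids = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
    #     print(lookup(ids).shape)  # (2, 2, 2)
    #
    #     # With max_norm set, each returned embedding is clipped by ClipByNorm over the
    #     # embedding axes (see construct below).
    #     clipped = nn.EmbeddingLookup(4, 2, max_norm=1.0)
    #     print(clipped(ids).shape)  # (2, 2, 2)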

    def _set_cache_enable(self):
        """EmbeddingLookup cache check for the non-PS env, which only supports 'Ascend'."""
        if self.target != 'DEVICE':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when 'target' is 'DEVICE', but got 'target': {self.target}")
        if not self.sparse:
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when 'sparse' is True, but got 'sparse': {self.sparse}.")
        if context.get_context("device_target") != 'Ascend':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid only "
                             f"when the device target is 'Ascend', but got {context.get_context('device_target')}.")

        logger.info("EmbeddingLookup cache enable takes effect.")
        self.forward_unique = True
        self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
        self.unique.add_prim_attr('cache_enable', True)
        self.embedding_table.cache_enable = self.cache_enable
        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
        self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')

    def _process_vocab_cache(self, slice_mode):
        """PS embeddingLookup cache check and process."""
        self.cache_enable = False
        if self.vocab_cache_size > 0:
            if self.target == 'CPU':
                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                               "current target is CPU, so it will be ignored.")
                return
            enable_ps = _get_ps_context("enable_ps")
            if not enable_ps:
                logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server "
                               "training mode, current mode is not parameter server training mode, "
                               "so it will be ignored.")
                return
            parallel_mode = _get_parallel_mode()
            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
            if is_auto_parallel:
                rank_size = get_group_size()
                rank_id = get_rank()
                full_batch = _get_full_batch()
                if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                    raise ValueError(f"For '{self.cls_name}', the cache of parameter server parallel should "
                                     f"only be used with \"full_batch\" set to True and 'slice_mode' set to "
                                     f"\"table_row_slice\", but got full_batch: {full_batch} and "
                                     f"'slice_mode': \"{slice_mode}\".")
                self.vocab_cache_size = self.vocab_cache_size * rank_size
                _set_rank_id(rank_id)
            self.cache_enable = True
            if _is_role_worker():
                self.vocab_size = self.vocab_cache_size
                if context.get_context("enable_sparse") != self.sparse:
                    raise ValueError(f"For '{self.cls_name}', the value of parameter 'sparse' must be the same "
                                     f"for all kernels and equal to the value of 'enable_sparse' in the context "
                                     f"setting in parameter server cache mode, but got the value of parameter "
                                     f"'sparse': {self.sparse} and the 'enable_sparse' in context setting: "
                                     f"{context.get_context('enable_sparse')}.")

    def _set_vocab_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """PS embeddingLookup cache enable set."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        _set_cache_enable(True)
        if self.sparse:
            self.forward_unique = True
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

    def construct(self, indices):
        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(indices) + (self.embedding_size,)
                indices_flatten = self.reshape_first(indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        return out
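

# The 'forward_unique' path in EmbeddingLookup.construct deduplicates indices before the gather
# and then scatters the rows back. A minimal numpy sketch of that identity, assuming numpy is
# available (it mirrors the unique -> gather -> gather_revert sequence above):
#
#     import numpy as np
#
#     table = np.arange(12).reshape(4, 3).astype(np.float32)   # embedding_table
#     ids = np.array([3, 1, 3, 0])                             # flattened indices
#     unique_id, unique_idx = np.unique(ids, return_inverse=True)
#     # Gathering the unique rows and indexing them back equals a direct gather:
#     assert (table[unique_id][unique_idx] == table[ids]).all()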


class MultiFieldEmbeddingLookup(EmbeddingLookup):
    r"""
    Returns a slice of the input tensor based on the specified indices and the field ids. This operation
    supports looking up embeddings using multi hot and one hot fields simultaneously.

    Note:
        When 'target' is set to 'CPU', this module will use
        P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which
        specifies 'offset = 0' for the lookup table.
        When 'target' is set to 'DEVICE', this module will use P.Gather() which
        specifies 'axis = 0' for the lookup table.
        The vectors with the same field_ids will be combined by the 'operator', such as 'SUM', 'MAX' and
        'MEAN'. Ensure that the input_values of the padded id are zero, so that they can be ignored. The
        final output will be zeros if the sum of absolute weights of the field is zero. This class only
        supports ['table_row_slice', 'batch_slice' and 'table_column_slice']. For the operator 'MAX' on
        the Ascend device, there is a constraint that batch_size * (seq_length + field_size) < 3500.

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        field_size (int): The field size of the final outputs.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string
            is specified. Default: 'normal'.
        target (str): Specifies the target where the op is executed. The value must be in
            ['DEVICE', 'CPU']. Default: 'CPU'.
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be obtained
            through nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        feature_num_list (tuple): The accompaniment array in field slice mode. This is currently unused.
            Default: None.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
            or None. Default: None.
        sparse (bool): Use sparse mode. When 'target' is set to 'CPU', 'sparse' must be True. Default: True.
        operator (str): The pooling method for the features in one field. Supports 'SUM', 'MEAN' and 'MAX'.
            Default: 'SUM'.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the indices of elements of the original Tensor. input_indices must be a 2d tensor in
          this interface. Type is Int32, Int64.
        - **input_values** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the weights of elements of the input_indices. The lookup vector will be multiplied by
          the input_values. Type is Float32.
        - **field_ids** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
          Specifies the field id of elements of the input_indices. Type is Int32.

    Outputs:
        Tensor, the shape of tensor is :math:`(batch\_size, field\_size, embedding\_size)`. Type is Float32.

    Raises:
        TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
        TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice', 'field_slice', 'table_row_slice' or
            'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
        ValueError: If `operator` is not one of 'SUM', 'MAX', 'MEAN'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
        >>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
        >>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
        >>> net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='SUM', target='DEVICE')
        >>> out = net(input_indices, input_values, field_ids)
        >>> print(out.shape)
        (2, 2, 2)
    """
    OPERATOR_SUM = 'SUM'
    OPERATOR_MEAN = 'MEAN'
    OPERATOR_MAX = 'MAX'

    def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'):
        """Initialize MultiFieldEmbeddingLookup."""
        super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
                                                        slice_mode, feature_num_list, max_norm, sparse)
        self.field_size = validator.check_positive_int(field_size, 'field_size', self.cls_name)
        self.operator = operator

        self.mul = P.Mul()
        self.inf_mask_mul = P.Mul()
        self.bias_add = P.Add()
        self.inf_add = P.Add()
        self.merge_op = None
        self.count_op = P.UnsortedSegmentSum()
        self.abs = P.Abs()
        self.equal = P.Equal()
        self.add = P.Add()
        self.cast = P.Cast()
        self.div_no_nan = P.DivNoNan()
        self.expand = P.ExpandDims()
        self.max_mask_mul = P.Mul()
        self.max_no_equal = P.NotEqual()

        validator.check_string(operator, ['SUM', 'MAX', 'MEAN'], 'operator', self.cls_name)
        if operator == MultiFieldEmbeddingLookup.OPERATOR_SUM:
            self.merge_op = P.UnsortedSegmentSum()
        elif operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
            self.merge_op = P.UnsortedSegmentMax()
        else:
            self.merge_op = P.UnsortedSegmentSum()

        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if slice_mode in ["table_row_slice", "batch_slice"] and is_auto_parallel:
            self.merge_op.shard(((get_group_size(), 1, 1), (get_group_size(), 1)))
            self.expand.shard(((get_group_size(),),))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
            self.count_op.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.add.shard(((get_group_size(),), (get_group_size(),)))
            self.div_no_nan.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_mask_mul.shard(((get_group_size(), 1), (get_group_size(), 1)))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((get_group_size(), 1, 1), ()))
                self.inf_mask_mul.shard(((get_group_size(), 1, 1), ()))
                self.merge_op.shard(((get_group_size(), 1), (get_group_size(),)))
                self.count_op.shard(((get_group_size(),), (get_group_size(),)))
                self.inf_add.shard(((get_group_size(), 1, 1), (get_group_size(), 1, 1)))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            self.merge_op.shard(((1, 1, get_group_size()), (1, 1)))
            self.div_no_nan.shard(((1, get_group_size()), (1, 1)))
            self.bias_add.shard(((1, 1), (1, 1)))
            self.mul.shard(((1, 1, 1), (1, 1, get_group_size())))
            self.count_op.shard(((1, 1), (1, 1)))
            self.add.shard(((1,), (1,)))
            self.max_mask_mul.shard(((1, get_group_size()), (1, 1)))
            self.expand.shard(((1,),))
            self.max_no_equal.shard(((1,), ()))
            if operator == MultiFieldEmbeddingLookup.OPERATOR_MAX:
                self.equal.shard(((1, 1, 1), ()))
                self.inf_mask_mul.shard(((1, 1, 1), ()))
                self.merge_op.shard(((1, get_group_size()), (1,)))
                self.count_op.shard(((1,), (1,)))
                self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
        else:
            if is_auto_parallel:
                raise ValueError("For '{}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice', "
                                 "'table_column_slice'], but got \"{}\".".format(self.cls_name, str(slice_mode)))

        # Min value for fp32
        self.negative_inf_value = -3.402823466E+38
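
    # How construct() pools per (batch, field): field_ids are offset by batch_index * field_size
    # (the Range/bias_add step below), so every (batch, field) pair gets a unique segment id in
    # [0, batch_size * field_size), and a single UnsortedSegment op can pool all samples at once.
    # A small illustration with hypothetical values, batch_size=2 and field_size=2:
    #
    #     field_ids = [[0, 1, 1],     bias = [[0],     offset field_ids = [[0, 1, 1],
    #                  [0, 0, 1]]             [2]]                         [2, 2, 3]]
    #
    # Segments 0/1 collect row 0's field-0/field-1 features, segments 2/3 collect row 1's.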

    def construct(self, input_indices, input_values, field_ids):

        _check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
        _check_input_2d(F.shape(input_values), "input_values", self.cls_name)
        _check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
        _check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
        _check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
        _check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)

        batch_size = self.shape(input_indices)[0]
        num_segments = batch_size * self.field_size
        bias = Range(0, num_segments, self.field_size)()
        bias = self.reshape(bias, (batch_size, -1))
        field_ids = self.bias_add(field_ids, bias)

        if self.target == "CPU":
            out = self.embeddinglookup(self.embedding_table, input_indices, 0)
        else:
            if self.forward_unique:
                shp = self.shape(input_indices) + (self.embedding_size,)
                indices_flatten = self.reshape(input_indices, (-1,))
                unique_id, unique_idx = self.unique(indices_flatten)
                weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                out = self.reshape(weight_flatten, shp)
            else:
                out = self.gatherv2(self.embedding_table, input_indices, 0)
        if self.max_norm is not None:
            axis = _make_axis_range(F.rank(input_indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)

        weights = self.reshape(input_values, (batch_size, self.shape(input_indices)[1], 1))
        embedding = self.mul(weights, out)

        if self.operator == 'MAX':
            # Fill the padding positions with -inf so that the padded values do not influence the result.
            negative_inf_mask = self.cast(self.equal(weights, 0), mstype.float32)
            inf_mask = self.inf_mask_mul(negative_inf_mask, self.negative_inf_value)
            embedding = self.inf_add(embedding, inf_mask)
            embedding = self.reshape(embedding, (-1, self.embedding_size))
            field_ids = self.reshape(field_ids, (-1,))

        merged_vectors = self.merge_op(embedding, field_ids, num_segments)

        if self.operator == 'MAX':
            value_count = self.count_op(self.abs(self.reshape(input_values, (-1,))), field_ids, num_segments)
            value_zeros = self.cast(self.max_no_equal(value_count, 0.0), mstype.float32)
            count = self.expand(value_zeros, -1)
            merged_vectors = self.max_mask_mul(merged_vectors, count)

        if self.operator == 'MEAN':
            value_count = self.count_op(self.abs(input_values), field_ids, num_segments)
            value_count = self.expand(value_count, -1)
            merged_vectors = self.div_no_nan(merged_vectors, value_count)

        merged_vectors = self.reshape(merged_vectors, (batch_size, self.field_size, -1))
        return merged_vectors
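

# A minimal usage sketch for MultiFieldEmbeddingLookup with operator='MEAN', assuming MindSpore
# is installed and the cell is exposed as nn.MultiFieldEmbeddingLookup. With 'MEAN', the pooled
# vector of each field is divided by the sum of the absolute input_values of that field (the
# DivNoNan step in construct), so zero-weighted padding ids drop out of the average.
#
#     import mindspore
#     from mindspore import Tensor, nn
#
#     input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
#     input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
#     field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
#     net = nn.MultiFieldEmbeddingLookup(10, 2, field_size=2, operator='MEAN', target='DEVICE')
#     print(net(input_indices, input_values, field_ids).shape)  # (2, 2, 2)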