• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""layers for second order optimization"""
16from __future__ import absolute_import
17
18import numpy as np
19
20import mindspore.common.dtype as mstype
21import mindspore.log as logger
22from mindspore.common.tensor import Tensor
23from mindspore.common.initializer import initializer, Initializer
24from mindspore.communication.management import get_group_size, get_rank
25from mindspore.ops import operations as P
26from mindspore.ops.operations._thor_ops import ThorIm2Col
27from mindspore.common.parameter import Parameter
28from mindspore import _checkparam as Validator
29from mindspore._checkparam import twice
30from mindspore import context
31from mindspore.nn.cell import Cell
32from mindspore.nn.layer.activation import get_activation
33from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
34    _set_rank_id, _insert_hash_table_size, _set_cache_enable
35from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
36from mindspore.context import ParallelMode
37from mindspore.ops import functional as F
38from mindspore.nn.layer.basic import ClipByNorm
39from mindspore.ops.primitive import constexpr
40
# Public THOR layers exported by this module.
__all__ = [
    'DenseThor',
    'Conv2dThor',
    'EmbeddingThor',
    'EmbeddingLookupThor',
]
42
43
class DenseThor(Cell):
    r"""
    The dense connected layer and saving the information needed for THOR.

    Applies dense connected layer for the input and saves the information A and G in the dense connected layer
    needed for THOR.

    This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function, :math:`\text{kernel}` is a weight matrix with the same
    data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the inputs created by the layer (only if has_bias is ``True`` ).

    While the `thor` attribute is ``True``, each forward pass stores :math:`A = x^T x / N` in `matrix_a`, and each
    backward pass stores :math:`G = dout^T dout / N` in `matrix_g`. Here :math:`N` is the batch size (the first
    dimension of `x`) and `dout` is the gradient w.r.t. the matmul output. On Ascend, `matrix_g` is not refreshed
    for layers with ``out_channels == 2``.

    Args:
        in_channels (int): The number of the input channels.
        out_channels (int): The number of the output channels.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as `x`. The values of str refer to the function `initializer`. Default: ``'normal'`` .
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as `x`. The values of str refer to the function `initializer`. Default: ``'zeros'`` .
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True`` .
        activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
            Default: ``None`` .

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Raises:
        ValueError: If the shape of `weight_init` or `bias_init` is incorrect.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> x = ms.Tensor(np.array([[1, 2, 3], [3, 4, 5]]), ms.float32)
        >>> net = ms.nn.DenseThor(3, 4, weight_init="ones")
        >>> output = net(x)
        >>> print(output)
        [[ 6.  6.  6.  6.]
         [12. 12. 12. 12.]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        """Initialize DenseThor."""
        super(DenseThor, self).__init__()
        # While True, construct/save_gradient refresh matrix_a and matrix_g.
        self.thor = True
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError(f"For '{self.cls_name}', weight init shape error. The dim of 'weight_init' should "
                                 f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
                                 f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
                                 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The dim of 'bias_init' should "
                                     f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                     f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
            self.bias_add = P.BiasAdd()

        # x @ weight^T, since weight is stored as (out_channels, in_channels).
        self.matmul = P.MatMul(transpose_b=True)
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        # THOR statistics, identity-initialized; not trained (requires_grad=False).
        self.matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float32)),
                                  name='matrix_a', requires_grad=False)
        self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
                                  name="matrix_g", requires_grad=False)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.mul = P.Mul()
        self.is_ascend = True
        self.split_dim = 128
        if context.get_context("device_target") == "Ascend":
            self._process_ascend_dense_thor(out_channels, in_channels)
        else:
            self.is_ascend = False
            # cube_matmul(a, a) computes a^T a (contracting the first dimension).
            self.cube_matmul = P.MatMul(transpose_a=True)
        # Hooks save_gradient onto the gradient of the tensor passed through self.getG in construct.
        self.getG = P.InsertGradientOf(self.save_gradient)

    def _process_ascend_dense_thor(self, out_channels, in_channels):
        """Create the Ascend-specific operators used to compute matrix_a and matrix_g.

        `in_channels` is accepted for signature compatibility but is not used. `self.matmul` has already
        been created by `__init__` with the same configuration, so it is not re-created here.
        """
        # cube_matmul(a, a) computes a^T a with the custom Ascend cube kernel.
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.cast = P.Cast()
        # For 2-output layers save_gradient skips refreshing matrix_g on Ascend
        # (presumably the next-sentence-prediction head, going by the name -- TODO confirm).
        self.is_nsp_layer = (out_channels == 2)

    def save_gradient(self, dout):
        """
        Store G = dout^T dout / N in `matrix_g` and return `dout` unchanged.

        This function is only meant for the THOR optimizer; it is registered through
        `P.InsertGradientOf` in `__init__`. `N` is the first dimension of `dout`.
        """
        out = dout
        if self.is_ascend:
            if not self.is_nsp_layer:
                shape = self.shape(dout)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_g = self.cube_matmul(dout, dout)
                matrix_g = self.mul(matrix_g, 1.0 / normalizer)
                self.matrix_g = matrix_g
        else:
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_g = self.cube_matmul(dout, dout)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.matrix_g = matrix_g
        return out

    def construct(self, x):
        """Apply the dense layer; while `thor` is True also refresh `matrix_a` and hook `save_gradient`."""
        if self.thor:
            if self.is_ascend:
                inputs = self.cube_matmul(x, x)
                shape = self.shape(x)
                normalizer = self.cast(shape[0], mstype.float32)
                matrix_a = self.mul(inputs, 1.0 / normalizer)
                self.matrix_a = matrix_a
            else:
                inputs = self.cube_matmul(x, x)
                # Normalize by the batch size, i.e. the dimension contracted by x^T x. This matches
                # the Ascend branch and save_gradient. shape(inputs)[0] would be in_channels.
                x_shape = self.shape(x)
                normalizer = x_shape[0]
                matrix_a = self.mul(inputs, 1.0 / normalizer)
                self.matrix_a = matrix_a
            x = self.matmul(x, self.weight)
            x = self.getG(x)
        else:
            x = self.matmul(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.activation_flag:
            x = self.activation(x)
        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
        # for it's used in 'save_gradient' gradient procedure.
        return F.depend(x, self.matrix_g)

    def extend_repr(self):
        """Return the extra description shown in the cell's repr."""
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        return s
205
206
class _ConvThor(Cell):
    """
    Applies a N-D convolution over an input signal composed of multiple input planes.

    Base class of the THOR convolution layers. It validates the convolution hyper-parameters and
    creates the `weight` parameter and, if requested, the `bias` parameter. Subclasses create the
    actual convolution operators.

    Args:
        in_channels (int): Number of input channels. Must be positive and divisible by `group`.
        out_channels (int): Number of output channels. Must be positive and divisible by `group`.
        kernel_size (tuple[int]): Kernel size. Its first two elements must be ints >= 1.
        stride (tuple[int]): Stride. Its first two elements must be ints >= 1.
        pad_mode (str): Padding mode, stored unchanged for subclasses.
        padding (Union[int, tuple[int]]): Non-negative padding value(s).
        dilation (tuple[int]): Dilation. Its first two elements must be ints >= 1.
        group (int): Number of groups. Must be positive.
        has_bias (bool): Whether to create a `bias` parameter.
        weight_init: Initializer for `weight`; any value accepted by `initializer`.
        bias_init: Initializer for `bias`. When `has_bias` is False it is ignored, with a warning
            unless it is the default ``'zeros'``.
        transposed (bool): If True the weight shape is ``[in_channels, out_channels // group, *kernel_size]``,
            otherwise ``[out_channels, in_channels // group, *kernel_size]``. Default: ``False``.

    Raises:
        TypeError: If `padding` is neither an int nor a tuple.
        ValueError: If `kernel_size`, `stride` or `dilation` is invalid, or if `in_channels` /
            `out_channels` is not divisible by `group`. The Validator checks on channels, group and
            padding may also raise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode,
                 padding, dilation, group, has_bias, weight_init, bias_init, transposed=False):
        """Initialize _ConvThor."""
        super(_ConvThor, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.bias_init = bias_init
        if isinstance(padding, tuple):
            for pad in padding:
                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
            self.padding = padding
        elif isinstance(padding, int):
            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
            self.padding = padding
        else:
            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got "
                            f"{type(padding).__name__}.")

        self.dilation = dilation
        self.group = Validator.check_positive_int(group, "group", self.cls_name)
        self.has_bias = has_bias
        self.__validate_kernel_size(kernel_size)
        self.__validate_stride(stride)
        self.__validate_dilation(dilation)
        if in_channels % group != 0:
            raise ValueError(f"For '{self.cls_name}', the 'in_channels' must be divisible by 'group', but got "
                             f"'in_channels': {in_channels} and 'group': {group}.")
        if out_channels % group != 0:
            raise ValueError(f"For '{self.cls_name}', the 'out_channels' must be divisible by 'group', but got "
                             f"'out_channels': {out_channels} and 'group': {group}.")
        if not transposed:
            shape = [out_channels, in_channels // group, *kernel_size]
        else:
            shape = [in_channels, out_channels // group, *kernel_size]
        self.weight = Parameter(initializer(weight_init, shape), name='weight')

        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
            self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
        else:
            # Compare with the default only when bias_init is a string. `!=` against a Tensor
            # initializer is not a plain bool comparison and could fail inside `if`.
            if not (isinstance(self.bias_init, str) and self.bias_init == 'zeros'):
                logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
            self.bias = None

    def _check_positive_int_pair(self, value, arg_name):
        """Raise ValueError unless value[0] and value[1] are both non-bool ints >= 1."""
        for item in (value[0], value[1]):
            # bool is a subclass of int, so it is rejected explicitly.
            if not isinstance(item, int) or isinstance(item, bool) or item < 1:
                raise ValueError(f"For '{self.cls_name}', all elements in '{arg_name}' must be int or tuple and "
                                 f"equal to or greater than 1, but got '{arg_name}': {value}.")

    def __validate_kernel_size(self, kernel_size):
        """validate kernel size."""
        self._check_positive_int_pair(kernel_size, 'kernel_size')

    def __validate_stride(self, stride):
        """validate stride."""
        self._check_positive_int_pair(stride, 'stride')

    def __validate_dilation(self, dilation):
        """validate dilation."""
        self._check_positive_int_pair(dilation, 'dilation')
279
280
class Conv2dThor(_ConvThor):
    r"""
    2D convolution layer and saving the information needed for THOR.

    Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
    where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in}` are height and width.
    And saves the information A and G in the 2D convolution layer needed for THOR.

    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:

    .. math::

        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

    where :math:`ccor` is the cross-correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
    to split the input `x` in the channel dimension.

    If the 'pad_mode' is set to be "valid", the output height and width will be
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    Note:
        For Ascend, the type of inputs should be subclass of Tensor[Float16], Tensor[Int8].
        For GPU, the type of inputs should be subclass of Tensor[Float32].
        The im2col operators used to collect matrix A are always created with "same" padding,
        independently of `pad_mode`.

    Args:
        in_channels (int): The number of the input channel :math:`C_{in}`.
        out_channels (int): The number of the output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means that the value is not only the height, but also
            the width of the kernel. A tuple of 2 integers means the height and the width of the kernel respectively.
        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number represents the height and width
            of movement, or a tuple of two int numbers that represent height and width of movement, respectively.
            Default: ``1`` .
        pad_mode (str): Specifies padding mode. The optional values are
            ``"same"`` , ``"valid"`` , ``"pad"`` . Default: ``"same"`` .

            - ``"same"``: Adopts the way of completion. The shape of the output will be the same as
              the `x`. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
              must be 0.

            - ``"valid"``: Adopts the way of discarding. The possible largest height and width of output will be
              returned without padding. Extra pixels will be discarded. If this mode is set, `padding` must be 0.

            - ``"pad"``: Implicit paddings on both sides of the input `x`. The number of `padding` will be padded to
              the input Tensor borders. `padding` must be greater than or equal to 0.

        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input `x`. If `padding` is an integer,
            the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
            with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
            padding[1], padding[2], and padding[3] accordingly. Default: ``0`` .
        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
            to use for dilated convolution. If set to be :math:`k > 1`, there will
            be :math:`k - 1` pixels skipped for each sampling location. Its value must
            be greater or equal to 1 and bounded by the height and width of the input `x`.
            Default: ``1`` .
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`,
            this 2D convolution layer also can be called 2D depthwise convolution layer. Default: ``1`` .
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` .
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the convolution kernel.
            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
            values from ``'TruncatedNormal'`` , ``'Normal'`` , ``'Uniform'`` , ``'HeUniform'`` and ``'XavierUniform'``
            distributions as well as constant ``'One'`` and ``'Zero'`` distributions are possible. Alias
            ``'xavier_uniform'`` , ``'he_uniform'`` , ``'ones'`` and ``'zeros'`` are acceptable. Uppercase and
            lowercase are both acceptable. Refer to the values of Initializer for more details. Default: ``'normal'`` .
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: ``'zeros'`` .

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> net = ms.nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
        >>> # for Ascend
        >>> x = ms.Tensor(np.ones([1, 120, 1024, 640]), ms.float16)
        >>> print(net(x).shape)
        (1, 240, 1024, 640)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 pad_mode='same', padding=0, dilation=1, group=1, has_bias=False,
                 weight_init='normal', bias_init='zeros'):
        """Initialize Conv2dThor."""
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        # Keep the raw dilation: the Ascend depthwise path passes it to DepthwiseConv2dNative unexpanded.
        self._dilation = dilation
        dilation = twice(dilation)
        super(Conv2dThor, self).__init__(in_channels, out_channels, kernel_size,
                                         stride, pad_mode, padding, dilation, group, has_bias, weight_init, bias_init)
        self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size,
                               mode=1, pad_mode=self.pad_mode, pad=self.padding,
                               stride=self.stride, dilation=self.dilation, group=self.group)
        # May replace self.conv2d / self.weight with a depthwise variant (Ascend, group > 1).
        self._init_depthwise_conv2d(weight_init)
        self.bias_add = P.BiasAdd()
        # While True, construct/save_gradient refresh matrix_a_cov and matrix_g_cov.
        self.thor = True
        self.hw = kernel_size[0] * kernel_size[1]
        self.matrix_a_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        self.matrix_g_dim = self.out_channels
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.cast = P.Cast()
        # Latest normalizers used for A and G, stored as shape-[1] tensors.
        self.a_normalizer = Parameter(initializer(1, [1], mstype.float32), name="a_normalizer", requires_grad=False)
        self.g_normalizer = Parameter(initializer(1, [1], mstype.float32), name="g_normalizer", requires_grad=False)
        self.is_ascend = True
        if context.get_context("device_target") == "Ascend":
            self._process_ascend_conv2d_thor(kernel_size, stride)
        else:
            self.is_ascend = False
            # NOTE: padding is hard-wired to "same" here regardless of self.pad_mode.
            self.img2col = ThorIm2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
            self.matmul = P.MatMul(transpose_b=True)
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            self.matrix_a_cov = Parameter(Tensor(np.zeros([self.matrix_a_dim, self.matrix_a_dim]).astype(np.float32)),
                                          name='matrix_a', requires_grad=False)
            self.matrix_g_cov = Parameter(Tensor(np.zeros([self.matrix_g_dim, self.matrix_g_dim]).astype(np.float32)),
                                          name='matrix_g', requires_grad=False)
        # Hooks save_gradient onto the gradient of the tensor passed through self.getG in construct.
        self.getG = P.InsertGradientOf(self.save_gradient)

    def _process_ascend_conv2d_thor(self, kernel_size, stride):
        """Create the Ascend-specific operators and identity-initialized A/G statistics.

        `self.reshape` was already created by `__init__`, so it is not re-created here.
        """
        ksizes = (1, kernel_size[0], kernel_size[1], 1)
        strides = (1, stride[0], stride[1], 1)
        ksizes_tbe = (kernel_size[0], kernel_size[1])
        self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
        self.transpose = P.Transpose()
        # cube_matmul(a, a) computes a^T a with the custom Ascend cube kernel.
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.diag_block_dim = 128
        self.matrix_a_cov = Parameter(Tensor(np.eye(self.matrix_a_dim).astype(np.float32)),
                                      name='matrix_a', requires_grad=False)
        self.matrix_g_cov = Parameter(Tensor(np.eye(self.matrix_g_dim).astype(np.float32)),
                                      name='matrix_g', requires_grad=False)
        self.slice = P.Slice()
        # NOTE: only stride[0] is passed and padding is hard-wired to "SAME", regardless of pad_mode.
        self.im2col = P.NewIm2Col(ksizes=ksizes_tbe, strides=stride[0], padding_mode="SAME")

    def _init_depthwise_conv2d(self, weight_init):
        """On Ascend with group > 1, replace the convolution with DepthwiseConv2dNative.

        Requires group == in_channels == out_channels (checked below) and re-creates `self.weight`
        with shape ``[1, in_channels, kh, kw]``.
        """
        if context.get_context("device_target") == "Ascend" and self.group > 1:
            self.dilation = self._dilation
            Validator.check_int('group', self.group, self.in_channels, Validator.EQ, self.cls_name)
            Validator.check_int('group', self.group, self.out_channels, Validator.EQ, self.cls_name)
            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                  kernel_size=self.kernel_size,
                                                  pad_mode=self.pad_mode,
                                                  pad=self.padding,
                                                  stride=self.stride,
                                                  dilation=self.dilation)
            weight_shape = [1, self.in_channels, *self.kernel_size]
            self.weight_init = weight_init
            if isinstance(weight_init, Tensor):
                # Swap the first two axes to match the depthwise weight layout.
                self.weight_init = weight_init.swapaxes(0, 1)
            if isinstance(weight_init, Initializer):
                # NOTE(review): this mutates the caller's Initializer object in place.
                self.weight_init.shape = weight_shape
            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')

    def save_gradient(self, dout):
        """Store the normalized G statistic from `dout` in `matrix_g_cov` and return `dout` unchanged."""
        out = dout
        if self.is_ascend:
            dout_shape = self.shape(dout)
            # Move axis 1 (channels, assuming NCHW) last and flatten to (-1, channels).
            dout = self.transpose(dout, (0, 2, 3, 1))
            dout = self.reshape(dout, (-1, dout_shape[1]))
            dout_shape = self.shape(dout)
            normalizer = dout_shape[0]
            matrix_g = self.cube_matmul(dout, dout)
            normalizer = self.cast(normalizer, mstype.float32)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.g_normalizer = self.reshape(Tensor(normalizer), (1,))
            self.matrix_g_cov = matrix_g
        else:
            # Average over the batch axis, then flatten the remaining axes after the first.
            dout = self.reduce_mean(dout, 0)
            dout_shape = self.shape(dout)
            dout = self.reshape(dout, (dout_shape[0], -1))
            dout_shape = self.shape(dout)
            normalizer = dout_shape[1]
            dout = self.cast(dout, mstype.float32)
            # matmul has transpose_b=True, so this is dout @ dout^T.
            matrix_g = self.matmul(dout, dout)
            matrix_g = self.mul(matrix_g, 1.0 / normalizer)
            self.g_normalizer = self.reshape(Tensor(normalizer), (1,))
            self.matrix_g_cov = matrix_g
        return out

    def construct(self, x):
        """Apply the convolution; while `thor` is True also refresh `matrix_a_cov` and hook `save_gradient`."""
        if self.thor:
            if self.is_ascend:
                matrix_a = self.im2col(x)
                matrix_a_shape = self.shape(matrix_a)
                y = matrix_a_shape[3]
                matrix_a = self.reshape(matrix_a, (-1, y))
                matrix_a_shape = self.shape(matrix_a)
                normalizer = matrix_a_shape[0]
                matrix_a = self.cube_matmul(matrix_a, matrix_a)
                normalizer = self.cast(normalizer, mstype.float32)
                matrix_a = self.mul(matrix_a, 1.0 / normalizer)
                self.a_normalizer = self.reshape(Tensor(normalizer), (1,))
                self.matrix_a_cov = matrix_a
                weight = self.cast(self.weight, mstype.float16)
                output = self.conv2d(x, weight)
                output = self.getG(output)
            else:
                matrix_a = self.img2col(x)
                matrix_a_shape = self.shape(matrix_a)
                matrix_a = self.reshape(matrix_a, (matrix_a_shape[0] * matrix_a_shape[1] * matrix_a_shape[2],
                                                   matrix_a_shape[3], -1))
                matrix_a = self.reduce_mean(matrix_a, 1)
                matrix_a_shape = self.shape(matrix_a)
                normalizer = matrix_a_shape[1]
                matrix_a = self.cast(matrix_a, mstype.float32)
                # matmul has transpose_b=True, so this is matrix_a @ matrix_a^T.
                matrix_a = self.matmul(matrix_a, matrix_a)
                matrix_a = self.mul(matrix_a, 1.0 / normalizer)
                self.a_normalizer = self.reshape(Tensor(normalizer), (1,))
                self.matrix_a_cov = matrix_a
                output = self.conv2d(x, self.weight)
                output = self.getG(output)
        else:
            if self.is_ascend:
                weight = self.cast(self.weight, mstype.float16)
                output = self.conv2d(x, weight)
            else:
                output = self.conv2d(x, self.weight)
        if self.has_bias:
            if self.is_ascend:
                bias = self.cast(self.bias, mstype.float16)
                output = self.bias_add(output, bias)
            else:
                output = self.bias_add(output, self.bias)
        return output

    def extend_repr(self):
        """Return the extra description shown in the cell's repr."""
        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, has_bias={}, ' \
            'bias_init={}'.format(self.in_channels, self.out_channels, self.kernel_size,
                                  self.stride, self.pad_mode, self.padding, self.dilation,
                                  self.group, self.has_bias, self.bias_init)
        return s
537
538
539class EmbeddingThor(Cell):
540    r"""
541    A simple lookup table that stores embeddings of a fixed dictionary and size
542    and saving the information needed for THOR.
543
544    This module is often used to store word embeddings and retrieve them using
545    indices. The input to the module is a list of indices, and the output is
546    the corresponding word embeddings. And saves the information A and G in the dense connected layer
547    needed for THOR.
548
549    Note:
550        When 'use_one_hot' is set to True, the type of the input `x` must be mindspore.int32.
551
552    Args:
553        vocab_size (int): The size of the dictionary of embeddings.
554        embedding_size (int): The size of each embedding vector.
555        use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
556        embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializes the embedding_table.
557            Refer to class `initializer` for the values of string when a string is specified. Default: ``'normal'`` .
558        dtype (:class:`mindspore.dtype`): Data type of input `x`. Default: ``mindspore.float32`` .
559        padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
560                                 will be initialized to zero. Default: ``None`` . The feature is inactivated.
561    Inputs:
562        - **x** (Tensor) - Tensor of input shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
563          the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
564          be zero.
565
566    Outputs:
567        Tensor of output shape :math:`(\text{batch_size}, \text{x_length}, \text{embedding_size})`.
568
569    Supported Platforms:
570        ``Ascend`` ``GPU``
571
572    Examples:
573        >>> import mindspore as ms
574        >>> import numpy as np
575        >>> net = ms.nn.EmbeddingThor(20000, 768,  True)
576        >>> x = ms.Tensor(np.ones([8, 128]), ms.int32)
577        >>>
578        >>> # Maps the input word IDs to word embedding.
579        >>> output = net(x)
580        >>> output.shape
581        (8, 128, 768)
582    """
583
584    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
585                 dtype=mstype.float32, padding_idx=None):
586        """Initialize EmbeddingThor."""
587        super(EmbeddingThor, self).__init__()
588        self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
589        self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
590        Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
591        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
592        self.use_one_hot = use_one_hot
593        self.dtype = dtype
594        self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
595        self.padding_idx = padding_idx
596        if padding_idx is not None:
597            self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Validator.INC_BOTH,
598                                                         "padding_idx", self.cls_name)
599            self.init_tensor[self.padding_idx] = 0
600        self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
601        self.expand = P.ExpandDims()
602        self.reshape_flat = P.Reshape()
603        self.shp_flat = (-1,)
604        self.gather = P.Gather()
605        self.one_hot = P.OneHot()
606        self.on_value = Tensor(1.0, self.dtype)
607        self.off_value = Tensor(0.0, self.dtype)
608        self.array_mul = P.MatMul()
609        self.reshape = P.Reshape()
610        self.get_shp = P.Shape()
611        self.thor = True
612        self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
613                                  name='matrix_a', requires_grad=False)
614        self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
615                                  name="matrix_g", requires_grad=False)
616        self.reduce_sum = P.ReduceSum(keep_dims=False)
617        self.getG = P.InsertGradientOf(self.save_gradient)
618        self.cast = P.Cast()
619        if context.get_context("device_target") == "Ascend":
620            self.cube_matmul = P.CusMatMulCube(transpose_a=True)
621        else:
622            self.cube_matmul = P.MatMul(transpose_a=True)
623        self.mul = P.Mul()
624
625    def save_gradient(self, dout):
626        """
627           this function only for thor optimizer
628           save_gradient
629        """
630        out = dout
631        shape = self.get_shp(dout)
632        normalizer = self.cast(shape[0], mstype.float32)
633        matrix_g = self.cube_matmul(dout, dout)
634        matrix_g = self.mul(matrix_g, 1.0 / normalizer)
635        self.matrix_g = matrix_g
636        return out
637
638    def construct(self, ids):
639        extended_ids = self.expand(ids, -1)
640        out_shape = self.get_shp(ids) + (self.embedding_size,)
641        flat_ids = self.reshape_flat(extended_ids, self.shp_flat)
642
643        if self.use_one_hot:
644            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
645            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
646        else:
647            if self.thor:
648                one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
649                matrix_a = self.reduce_sum(one_hot_ids, 0)
650                self.matrix_a = matrix_a
651                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
652                output_for_reshape = self.getG(output_for_reshape)
653            else:
654                output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
655
656        output = self.reshape(output_for_reshape, out_shape)
657        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
658        # for it's used in 'save_gradient' gradient procedure.
659        return F.depend(output, self.matrix_g)
660
661    def extend_repr(self):
662        s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
663            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
664        return s
665
666
@constexpr
def _make_axis_range(start, end):
    """Return the consecutive axes ``(start, start + 1, ..., end - 1)`` as a tuple.

    Wrapped in MindSpore's ``constexpr`` so that it can be evaluated when the graph is compiled.
    """
    return tuple(range(start, end))
671
672
class EmbeddingLookupThor(Cell):
    r"""
    Returns a slice of the input tensor based on the specified indices
    and saves the information needed for THOR.

    This module has the same function as EmbeddingLookup, but additionally saves the information A and G of the
    embedding lookup layer needed for THOR. The THOR information is only recorded when `target` is 'DEVICE';
    the 'CPU' path performs a plain lookup.

    Args:
        vocab_size (int): The size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
            Refer to class `initializer` for the values of string when a string is specified.
            Default: ``'normal'`` .
        target (str): Specifies the target where the op is executed. The value must be in
            [ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` .
        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
            nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
        manual_shapes (tuple): The accompaniment array in field slice mode.
        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None.
                                       Default: ``None`` .
        sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be ``true`` .
                       Default: ``True`` .
        vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in
            'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
            In addition, it should be noted that it will cost the 'DEVICE' memory, so suggests setting a reasonable
            value to avoid insufficient memory.

    Inputs:
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
        ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or
                    'table_row_slice' or 'table_column_slice'.
        ValueError: If `sparse` is False and `target` is 'CPU'.
        ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.
        TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
        TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
        ValueError: If `vocab_size` or `embedding_size` is less than 1.
        ValueError: If `vocab_cache_size` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> import numpy as np
        >>> input_indices = ms.Tensor(np.array([[1, 0], [3, 2]]), ms.int32)
        >>> result = ms.nn.EmbeddingLookupThor(4, 2)(input_indices)
        >>> print(result.shape)
        (2, 2, 2)
    """
    BATCH_SLICE = "batch_slice"
    FIELD_SLICE = "field_slice"
    TABLE_ROW_SLICE = "table_row_slice"
    TABLE_COLUMN_SLICE = "table_column_slice"

    def __init__(self, vocab_size, embedding_size, param_init='normal',
                 target='CPU', slice_mode='batch_slice', manual_shapes=None,
                 max_norm=None, sparse=True, vocab_cache_size=0):
        super(EmbeddingLookupThor, self).__init__()
        Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
        self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size', self.cls_name)
        self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size', self.cls_name)
        self.target = target
        self.sparse = sparse
        self.cache_enable = self.vocab_cache_size > 0
        self.forward_unique = False
        self.dtype = mstype.float16
        if target not in ('CPU', 'DEVICE'):
            raise ValueError(f"For '{self.cls_name}', the 'target' must be one of values in ('CPU', 'DEVICE'), "
                             f"but got {target}.")
        if not sparse and target == 'CPU':
            raise ValueError(f"For '{self.cls_name}', embedding_lookup must be sparse when 'target' is CPU, but got "
                             f"'sparse': {sparse}, 'target': {target}.")
        if sparse:
            self.gatherv2 = P.SparseGatherV2()
        else:
            self.gatherv2 = P.Gather()
        self.embeddinglookup = P.EmbeddingLookup().set_device('CPU')
        enable_ps = _get_ps_context("enable_ps")
        if enable_ps:
            # May reset cache_enable and, on PS workers, shrink vocab_size to the cache size.
            self._process_vocab_cache(slice_mode)
        self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
                                                     mstype.float16), name='embedding_table')
        parallel_mode = _get_parallel_mode()
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.gather_revert = P.Gather()
        self.reshape_first = P.Reshape()
        self.reshape = P.Reshape()
        self.unique = P.Unique()
        self.shape = P.Shape()
        if is_auto_parallel:
            self.unique = P.Unique().shard(((1,),))
        if self.cache_enable and enable_ps:
            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
            if is_auto_parallel:
                self.unique.add_prim_attr('cache_enable', True)
        # Rank of the indices fed to the gather ops; 1 when the forward path flattens them first.
        indices_shape_size = 2
        if slice_mode == "field_slice" and is_auto_parallel:
            if not manual_shapes:
                raise ValueError(f"For '{self.cls_name}', the 'manual_shapes' should not be none "
                                 f"when 'slice_mode' is 'field_slice'.")
            if not isinstance(manual_shapes, tuple):
                raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), but got "
                                f"type {type(manual_shapes).__name__}.")
            for dim in manual_shapes:
                Validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
            self.gatherv2.add_prim_attr("manual_split", manual_shapes)
            self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
            self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
            self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
        elif slice_mode == "table_row_slice" and is_auto_parallel:
            full_batch = _get_full_batch()
            if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
                indices_shape_size = 1
                self.gather_revert.shard(((1, 1), (get_group_size(),)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
            self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
        elif slice_mode == "table_column_slice" and is_auto_parallel:
            if target == 'DEVICE':
                indices_shape_size = 1
                self.gather_revert.shard(((1, get_group_size()), (1,)))
                self.forward_unique = True
            indices_strategy = (1,) * indices_shape_size
            self.gatherv2.shard(((1, get_group_size()), indices_strategy))
            self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
        elif slice_mode == "batch_slice" and is_auto_parallel:
            indices_strategy = [get_group_size()]
            indices_strategy.extend([1] * (indices_shape_size - 1))
            indices_strategy = tuple(indices_strategy)
            self.gatherv2.shard(((1, 1), indices_strategy))
            self.embeddinglookup.shard(((1, 1), indices_strategy))
        else:
            if is_auto_parallel:
                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be one of values in "
                                 f"['field_slice', 'table_row_slice', 'table_column_slice', 'batch_slice'], "
                                 f"but got 'slice_mode': {slice_mode}")
        if self.cache_enable and not enable_ps:
            if parallel_mode != ParallelMode.STAND_ALONE:
                raise ValueError(f"For '{self.cls_name}', the 'parallel_mode' must be equal to "
                                 f"'ParallelMode.STAND_ALONE', but got {parallel_mode}.")
            self._set_cache_enable()
        self.embedding_table.unique = self.forward_unique
        self.max_norm = max_norm
        if self.max_norm is not None:
            self.max_norm = Validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
            self.max_norm = Tensor(self.max_norm, dtype=mstype.float16)

        self.thor = True
        # THOR statistics; written in construct (matrix_a) and save_gradient (matrix_g), never trained.
        self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
                                  name='matrix_a', requires_grad=False)
        self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
                                  name="matrix_g", requires_grad=False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        # Hook that routes the backward gradient of the looked-up rows through save_gradient.
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.cast = P.Cast()
        self.cube_matmul = P.MatMul(transpose_a=True)
        self.mul = P.Mul()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
        self.one_hot = P.OneHot()

    def save_gradient(self, dout):
        """
           this function only for thor optimizer
           save_gradient

           Stores ``dout^T @ dout`` (with `dout` flattened to ``(-1, embedding_size)``) scaled by
           ``1 / dout.shape[0]`` into ``self.matrix_g`` and returns `dout` unchanged.
        """
        out = dout
        shape = self.shape(dout)
        normalizer = self.cast(shape[0], mstype.float16)
        dout = self.reshape(dout, (-1, self.embedding_size))
        matrix_g = self.cube_matmul(dout, dout)
        matrix_g = self.mul(matrix_g, 1.0 / normalizer)
        matrix_g = self.cast(matrix_g, mstype.float16)
        self.matrix_g = matrix_g
        return out

    def _set_cache_enable(self):
        """EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
        if self.target != 'DEVICE':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'target' is 'DEVICE', but got 'target': {self.target}.")
        if not self.sparse:
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'sparse' is true, but got 'sparse': {self.sparse}.")
        if context.get_context("device_target") != 'Ascend':
            raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
                             f"only when 'device_target' is 'Ascend', but got {context.get_context('device_target')}.")

        logger.info("EmbeddingLookup cache enable takes effect.")
        self.forward_unique = True
        self.unique = P.Unique().set_device('CPU')
        self.unique.add_prim_attr('cache_enable', True)
        self.embedding_table.cache_enable = self.cache_enable
        self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
        self.reshape_first = P.Reshape().set_device('CPU')

    def _process_vocab_cache(self, slice_mode):
        """PS embeddingLookup cache check and process."""
        self.cache_enable = False
        if self.vocab_cache_size > 0:
            if self.target == 'CPU':
                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
                               "current target is CPU, so it will be ignored.")
                return
            enable_ps = _get_ps_context("enable_ps")
            if not enable_ps:
                logger.warning(
                    "The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
                    "mode, current mode is not parameter server trainning mode, so it will be ignored.")
                return
            parallel_mode = _get_parallel_mode()
            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
            if is_auto_parallel:
                rank_size = get_group_size()
                rank_id = get_rank()
                full_batch = _get_full_batch()
                if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
                    raise ValueError(f"For '{self.cls_name}', the embeddingLookup cache of parameter server parallel "
                                     f"only be used in 'full_batch' and 'table_row_slice' parallel strategy, but got "
                                     f"'full_batch': {full_batch}, 'slice_mode': {slice_mode}.")
                self.vocab_cache_size = self.vocab_cache_size * rank_size
                _set_rank_id(rank_id)
            self.cache_enable = True
            if _is_role_worker():
                self.vocab_size = self.vocab_cache_size

    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
        """PS embeddingLookup cache enable set."""
        self.embedding_table.cache_enable = True
        self.embedding_table.is_param_ps = True
        _set_cache_enable(True)
        if self.sparse:
            self.forward_unique = True
        if _is_role_worker():
            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

    def construct(self, indices):
        if self.target == "CPU":
            # Plain lookup; THOR statistics are not recorded on this path.
            out = self.embeddinglookup(self.embedding_table, indices, 0)
        else:
            if self.thor:
                if self.forward_unique:
                    shp = self.shape(indices) + (self.embedding_size,)
                    indices_flatten = self.reshape_first(indices, (-1,))
                    unique_id, unique_idx = self.unique(indices_flatten)
                    # matrix_a: occurrence count of every vocabulary entry in this batch.
                    one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
                    matrix_a = self.reduce_sum(one_hot_ids, 0)
                    matrix_a = self.cast(matrix_a, mstype.float16)
                    self.matrix_a = matrix_a
                    weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                    # The hooked tensor must feed the rest of the forward graph; otherwise the
                    # InsertGradientOf node is dead and save_gradient never runs in backward.
                    weight_unique = self.getG(weight_unique)
                    weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                    out = self.reshape(weight_flatten, shp)
                else:
                    indices_flatten = self.reshape_first(indices, (-1,))
                    # matrix_a: occurrence count of every vocabulary entry in this batch.
                    one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
                    matrix_a = self.reduce_sum(one_hot_ids, 0)
                    matrix_a = self.cast(matrix_a, mstype.float16)
                    self.matrix_a = matrix_a
                    out = self.gatherv2(self.embedding_table, indices, 0)
                    out = self.getG(out)
            else:
                if self.forward_unique:
                    shp = self.shape(indices) + (self.embedding_size,)
                    indices_flatten = self.reshape_first(indices, (-1,))
                    unique_id, unique_idx = self.unique(indices_flatten)
                    weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
                    weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
                    out = self.reshape(weight_flatten, shp)
                else:
                    out = self.gatherv2(self.embedding_table, indices, 0)
        if self.max_norm is not None:
            # Clip over the trailing (embedding) axes that the lookup appended to the indices' shape.
            axis = _make_axis_range(F.rank(indices), F.rank(out))
            clip_by_norm = ClipByNorm(axis)
            out = clip_by_norm(out, self.max_norm)
        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
        # for it's used in 'save_gradient' gradient procedure.
        return F.depend(out, self.matrix_g)
964