1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""layers for second order optimization""" 16from __future__ import absolute_import 17 18import numpy as np 19 20import mindspore.common.dtype as mstype 21import mindspore.log as logger 22from mindspore.common.tensor import Tensor 23from mindspore.common.initializer import initializer, Initializer 24from mindspore.communication.management import get_group_size, get_rank 25from mindspore.ops import operations as P 26from mindspore.ops.operations._thor_ops import ThorIm2Col 27from mindspore.common.parameter import Parameter 28from mindspore import _checkparam as Validator 29from mindspore._checkparam import twice 30from mindspore import context 31from mindspore.nn.cell import Cell 32from mindspore.nn.layer.activation import get_activation 33from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \ 34 _set_rank_id, _insert_hash_table_size, _set_cache_enable 35from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch 36from mindspore.context import ParallelMode 37from mindspore.ops import functional as F 38from mindspore.nn.layer.basic import ClipByNorm 39from mindspore.ops.primitive import constexpr 40 41__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor'] 42 43 44class DenseThor(Cell): 45 r""" 46 The dense connected layer and saving the information needed for THOR. 47 48 Applies dense connected layer for the input and saves the information A and G in the dense connected layer 49 needed for THOR. 50 51 This layer implements the operation as: 52 53 .. math:: 54 \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}), 55 56 where :math:`\text{activation}` is the activation function , :math:`\text{kernel}` is a weight matrix with the same 57 data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector 58 with the same data type as the inputs created by the layer (only if has_bias is ``True`` ). 59 60 Args: 61 in_channels (int): The number of the input channels. 62 out_channels (int): The number of the output channels. 63 weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype 64 is same as `x`. The values of str refer to the function `initializer`. Default: ``'normal'`` . 65 bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is 66 same as `x`. The values of str refer to the function `initializer`. Default: ``'zeros'`` . 67 has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True`` . 68 activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'. 69 Default: ``None`` . 70 71 Inputs: 72 - **x** (Tensor) - Tensor of shape :math:`(N, in\_channels)`. 73 74 Outputs: 75 Tensor of shape :math:`(N, out\_channels)`. 76 77 Raises: 78 ValueError: If the shape of `weight_init` or `bias_init` is incorrect. 79 80 Supported Platforms: 81 ``Ascend`` ``GPU`` 82 83 Examples: 84 >>> import mindspore as ms 85 >>> import numpy as np 86 >>> x = ms.Tensor(np.array([[1, 2, 3], [3, 4, 5]]), ms.float32) 87 >>> net = ms.nn.DenseThor(3, 4, weight_init="ones") 88 >>> output = net(x) 89 >>> print(output) 90 [[ 6. 6. 6. 6.] 91 [ 12. 12. 12. 12. ]] 92 """ 93 94 def __init__(self, 95 in_channels, 96 out_channels, 97 weight_init='normal', 98 bias_init='zeros', 99 has_bias=True, 100 activation=None): 101 """Initialize DenseThor.""" 102 super(DenseThor, self).__init__() 103 self.thor = True 104 self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name) 105 self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name) 106 self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name) 107 if isinstance(weight_init, Tensor): 108 if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ 109 weight_init.shape[1] != in_channels: 110 raise ValueError(f"For '{self.cls_name}', weight init shape error. The dim of 'weight_init' should " 111 f"be equal to 2, and the first dim must be equal to 'out_channels', and the " 112 f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, " 113 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.") 114 self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") 115 self.bias = None 116 if self.has_bias: 117 if isinstance(bias_init, Tensor): 118 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: 119 raise ValueError(f"For '{self.cls_name}', bias init shape error. The dim of 'bias_init' should " 120 f"be equal to 1, and the first dim must be equal to 'out_channels'. But got " 121 f"'bias_init': {bias_init}, 'out_channels': {out_channels}.") 122 self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") 123 self.bias_add = P.BiasAdd() 124 125 self.matmul = P.MatMul(transpose_b=True) 126 self.activation = get_activation(activation) 127 self.activation_flag = self.activation is not None 128 129 self.matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float32)), 130 name='matrix_a', requires_grad=False) 131 self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)), 132 name="matrix_g", requires_grad=False) 133 self.shape = P.Shape() 134 self.reshape = P.Reshape() 135 self.transpose = P.Transpose() 136 self.mul = P.Mul() 137 self.is_ascend = True 138 self.split_dim = 128 139 if context.get_context("device_target") == "Ascend": 140 self._process_ascend_dense_thor(out_channels, in_channels) 141 else: 142 self.is_ascend = False 143 self.cube_matmul = P.MatMul(transpose_a=True) 144 self.getG = P.InsertGradientOf(self.save_gradient) 145 146 def _process_ascend_dense_thor(self, out_channels, in_channels): 147 """process ascend dense thor""" 148 self.matmul = P.MatMul(transpose_b=True) 149 self.cube_matmul = P.CusMatMulCube(transpose_a=True) 150 self.cast = P.Cast() 151 self.is_nsp_layer = (out_channels == 2) 152 153 def save_gradient(self, dout): 154 """ 155 this function only for thor optimizer 156 save_gradient 157 """ 158 out = dout 159 if self.is_ascend: 160 if not self.is_nsp_layer: 161 shape = self.shape(dout) 162 normalizer = self.cast(shape[0], mstype.float32) 163 matrix_g = self.cube_matmul(dout, dout) 164 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 165 self.matrix_g = matrix_g 166 else: 167 dout_shape = self.shape(dout) 168 normalizer = dout_shape[0] 169 matrix_g = self.cube_matmul(dout, dout) 170 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 171 self.matrix_g = matrix_g 172 return out 173 174 def construct(self, x): 175 if self.thor: 176 if self.is_ascend: 177 inputs = self.cube_matmul(x, x) 178 shape = self.shape(x) 179 normalizer = self.cast(shape[0], mstype.float32) 180 matrix_a = self.mul(inputs, 1.0 / normalizer) 181 self.matrix_a = matrix_a 182 else: 183 inputs = self.cube_matmul(x, x) 184 inputs_shape = self.shape(inputs) 185 normalizer = inputs_shape[0] 186 matrix_a = self.mul(inputs, 1.0 / normalizer) 187 self.matrix_a = matrix_a 188 x = self.matmul(x, self.weight) 189 x = self.getG(x) 190 else: 191 x = self.matmul(x, self.weight) 192 if self.has_bias: 193 x = self.bias_add(x, self.bias) 194 if self.activation_flag: 195 x = self.activation(x) 196 # We use Depend to make 'self.matrix_g' as primal graph's weight parameter, 197 # for it's used in 'save_gradient' gradient procedure. 198 return F.depend(x, self.matrix_g) 199 200 def extend_repr(self): 201 s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) 202 if self.has_bias: 203 s += ', has_bias={}'.format(self.has_bias) 204 return s 205 206 207class _ConvThor(Cell): 208 """ 209 Applies a N-D convolution over an input signal composed of multiple input planes. 210 """ 211 212 def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode, 213 padding, dilation, group, has_bias, weight_init, bias_init, transposed=False): 214 """Initialize _ConvThor.""" 215 super(_ConvThor, self).__init__() 216 self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name) 217 self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name) 218 self.kernel_size = kernel_size 219 self.stride = stride 220 self.pad_mode = pad_mode 221 self.bias_init = bias_init 222 if isinstance(padding, tuple): 223 for pad in padding: 224 Validator.check_non_negative_int(pad, 'padding item', self.cls_name) 225 self.padding = padding 226 elif isinstance(padding, int): 227 Validator.check_non_negative_int(padding, 'padding', self.cls_name) 228 self.padding = padding 229 else: 230 raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got " 231 f"{type(padding).__name__}.") 232 233 self.dilation = dilation 234 self.group = Validator.check_positive_int(group, "group", self.cls_name) 235 self.has_bias = has_bias 236 self.__validate_kernel_size(kernel_size) 237 self.__validate_stride(stride) 238 self.__validate_dilation(dilation) 239 if in_channels % group != 0: 240 raise ValueError(f"For '{self.cls_name}', the 'in_channels' must be divisible by 'group', but got " 241 f"'in_channels': {in_channels} and 'group': {group}.") 242 if out_channels % group != 0: 243 raise ValueError(f"For '{self.cls_name}', the 'out_channels' must be divisible by 'group', but got " 244 f"'out_channels': {out_channels} and 'group': {group}.") 245 if not transposed: 246 shape = [out_channels, in_channels // group, *kernel_size] 247 else: 248 shape = [in_channels, out_channels // group, *kernel_size] 249 self.weight = Parameter(initializer(weight_init, shape), name='weight') 250 251 if Validator.check_bool(has_bias, "has_bias", self.cls_name): 252 self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias') 253 else: 254 if self.bias_init != 'zeros': 255 logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.") 256 self.bias = None 257 258 def __validate_kernel_size(self, kernel_size): 259 """validate kernel size.""" 260 if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \ 261 isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \ 262 kernel_size[0] < 1 or kernel_size[1] < 1: 263 raise ValueError(f"For '{self.cls_name}', all elements in 'kernel_size' must be int or tuple and " 264 f"equal to or greater than 1, but got 'kernel_size': {kernel_size}.") 265 266 def __validate_stride(self, stride): 267 """validate stride.""" 268 if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \ 269 isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1: 270 raise ValueError(f"For '{self.cls_name}', all elements in 'stride' must be int or tuple and " 271 f"equal to or greater than 1, but got 'stride': {stride}.") 272 273 def __validate_dilation(self, dilation): 274 """validate dilation.""" 275 if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \ 276 isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1: 277 raise ValueError(f"For '{self.cls_name}', all elements in 'dilation' must be int or tuple and " 278 f"equal to or greater than 1, but got 'dilation': {dilation}.") 279 280 281class Conv2dThor(_ConvThor): 282 r""" 283 2D convolution layer and saving the information needed for THOR. 284 285 286 Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`, 287 where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in})` are height and width. 288 And saves the information A and G in the 2D convolution layer needed for THOR. 289 290 For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as: 291 292 293 .. math:: 294 295 out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j, 296 297 where :math:`ccor` is the cross-correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges 298 from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th 299 filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice 300 of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and 301 :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape 302 :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number 303 to split the input `x` in the channel dimension. 304 305 If the 'pad_mode' is set to be "valid", the output height and width will be 306 :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} - 307 (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and 308 :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - 309 (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. 310 311 Note: 312 For Ascend, the type of inputs should be subclass of Tensor[Float16], Tensor[Int8]. 313 For GPU, the type of inputs should be subclass of Tensor[Float32]. 314 315 Args: 316 in_channels (int): The number of the input channel :math:`C_{in}`. 317 out_channels (int): The number of the output channel :math:`C_{out}`. 318 kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height 319 and width of the 2D convolution window. Single int means that the value is not only the height, but also 320 the width of the kernel. A tuple of 2 integers means the height and the width of the kernel respectively. 321 stride (Union[int, tuple[int]]): The distance of kernel moving, an int number represents the height and width 322 of movement, or a tuple of two int numbers that represent height and width of movement, respectively. 323 Default: ``1`` . 324 pad_mode (str): Specifies padding mode. The optional values are 325 ``"same"`` , ``"valid"`` , ``"pad"`` . Default: ``"same"`` . 326 327 - ``"same"``: Adopts the way of completion. The shape of the output will be the same as 328 the `x`. The total number of padding will be calculated in horizontal and vertical 329 directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the 330 last extra padding will be done from the bottom and the right side. If this mode is set, `padding` 331 must be 0. 332 333 - ``"valid"``: Adopts the way of discarding. The possible largest height and width of output will be 334 returned without padding. Extra pixels will be discarded. If this mode is set, `padding` must be 0. 335 336 - ``"pad"``: Implicit paddings on both sides of the input `x`. The number of `padding` will be padded to 337 the input Tensor borders. `padding` must be greater than or equal to 0. 338 339 padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input `x`. If `padding` is an integer, 340 the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple 341 with four integers, the paddings of top, bottom, left and right will be equal to padding[0], 342 padding[1], padding[2], and padding[3] accordingly. Default: ``0`` . 343 dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate 344 to use for dilated convolution. If set to be :math:`k > 1`, there will 345 be :math:`k - 1` pixels skipped for each sampling location. Its value must 346 be greater or equal to 1 and bounded by the height and width of the input `x`. 347 Default: ``1`` . 348 group (int): Splits filter into groups, `in_ channels` and `out_channels` must be 349 divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`, 350 this 2D convolution layer also can be called 2D depthwise convolution layer. Default: ``1`` . 351 has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``False`` . 352 weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the convolution kernel. 353 It can be a Tensor, a string, an Initializer or a number. When a string is specified, 354 values from ``'TruncatedNormal'`` , ``'Normal'`` , ``'Uniform'`` , ``'HeUniform'`` and ``'XavierUniform'`` 355 distributions as well as constant ``'One'`` and ``'Zero'`` distributions are possible. Alias 356 ``'xavier_uniform'`` , ``'he_uniform'`` , ``'ones'`` and ``'zeros'`` are acceptable. Uppercase and 357 lowercase are both acceptable. Refer to the values of Initializer for more details. Default: ``'normal'`` . 358 bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializes the bias vector. Possible 359 Initializer and string are the same as 'weight_init'. Refer to the values of 360 Initializer for more details. Default: ``'zeros'`` . 361 362 Inputs: 363 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 364 365 Outputs: 366 Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 367 368 Supported Platforms: 369 ``Ascend`` ``GPU`` 370 371 Examples: 372 >>> import mindspore as ms 373 >>> import numpy as np 374 >>> net = ms.nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal') 375 >>> # for Ascend 376 >>> x = ms.Tensor(np.ones([1, 120, 1024, 640]), ms.float16) 377 >>> print(net(x).shape) 378 (1, 240, 1024, 640) 379 """ 380 381 def __init__(self, in_channels, out_channels, kernel_size, stride=1, 382 pad_mode='same', padding=0, dilation=1, group=1, has_bias=False, 383 weight_init='normal', bias_init='zeros'): 384 """Initialize Conv2dThor.""" 385 kernel_size = twice(kernel_size) 386 stride = twice(stride) 387 self._dilation = dilation 388 dilation = twice(dilation) 389 super(Conv2dThor, self).__init__(in_channels, out_channels, kernel_size, 390 stride, pad_mode, padding, dilation, group, has_bias, weight_init, bias_init) 391 self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, 392 mode=1, pad_mode=self.pad_mode, pad=self.padding, 393 stride=self.stride, dilation=self.dilation, group=self.group) 394 self._init_depthwise_conv2d(weight_init) 395 self.bias_add = P.BiasAdd() 396 self.thor = True 397 self.hw = kernel_size[0] * kernel_size[1] 398 self.matrix_a_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] 399 self.matrix_g_dim = self.out_channels 400 self.shape = P.Shape() 401 self.reshape = P.Reshape() 402 self.mul = P.Mul() 403 self.cast = P.Cast() 404 self.a_normalizer = Parameter(initializer(1, [1], mstype.float32), name="a_normalizer", requires_grad=False) 405 self.g_normalizer = Parameter(initializer(1, [1], mstype.float32), name="g_normalizer", requires_grad=False) 406 self.is_ascend = True 407 if context.get_context("device_target") == "Ascend": 408 self._process_ascend_conv2d_thor(kernel_size, stride) 409 else: 410 self.is_ascend = False 411 self.img2col = ThorIm2Col(kernel_size=kernel_size, stride=stride, pad_mode="same") 412 self.matmul = P.MatMul(transpose_b=True) 413 self.reduce_mean = P.ReduceMean(keep_dims=False) 414 self.matrix_a_cov = Parameter(Tensor(np.zeros([self.matrix_a_dim, self.matrix_a_dim]).astype(np.float32)), 415 name='matrix_a', requires_grad=False) 416 self.matrix_g_cov = Parameter(Tensor(np.zeros([self.matrix_g_dim, self.matrix_g_dim]).astype(np.float32)), 417 name='matrix_g', requires_grad=False) 418 self.getG = P.InsertGradientOf(self.save_gradient) 419 420 def _process_ascend_conv2d_thor(self, kernel_size, stride): 421 """process ascend conv2d thor""" 422 ksizes = (1, kernel_size[0], kernel_size[1], 1) 423 strides = (1, stride[0], stride[1], 1) 424 ksizes_tbe = (kernel_size[0], kernel_size[1]) 425 self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) 426 self.transpose = P.Transpose() 427 self.reshape = P.Reshape() 428 self.cube_matmul = P.CusMatMulCube(transpose_a=True) 429 self.diag_block_dim = 128 430 self.matrix_a_cov = Parameter(Tensor(np.eye(self.matrix_a_dim).astype(np.float32)), 431 name='matrix_a', requires_grad=False) 432 self.matrix_g_cov = Parameter(Tensor(np.eye(self.matrix_g_dim).astype(np.float32)), 433 name='matrix_g', requires_grad=False) 434 self.slice = P.Slice() 435 self.im2col = P.NewIm2Col(ksizes=ksizes_tbe, strides=stride[0], padding_mode="SAME") 436 437 def _init_depthwise_conv2d(self, weight_init): 438 """Initialize depthwise conv2d op""" 439 if context.get_context("device_target") == "Ascend" and self.group > 1: 440 self.dilation = self._dilation 441 Validator.check_int('group', self.group, self.in_channels, Validator.EQ, self.cls_name) 442 Validator.check_int('group', self.group, self.out_channels, Validator.EQ, self.cls_name) 443 self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1, 444 kernel_size=self.kernel_size, 445 pad_mode=self.pad_mode, 446 pad=self.padding, 447 stride=self.stride, 448 dilation=self.dilation) 449 weight_shape = [1, self.in_channels, *self.kernel_size] 450 self.weight_init = weight_init 451 if isinstance(weight_init, Tensor): 452 self.weight_init = weight_init.swapaxes(0, 1) 453 if isinstance(weight_init, Initializer): 454 self.weight_init.shape = weight_shape 455 self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight') 456 457 def save_gradient(self, dout): 458 """save_gradient""" 459 out = dout 460 if self.is_ascend: 461 dout_shape = self.shape(dout) 462 dout = self.transpose(dout, (0, 2, 3, 1)) 463 dout = self.reshape(dout, (-1, dout_shape[1])) 464 dout_shape = self.shape(dout) 465 normalizer = dout_shape[0] 466 matrix_g = self.cube_matmul(dout, dout) 467 normalizer = self.cast(normalizer, mstype.float32) 468 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 469 self.g_normalizer = self.reshape(Tensor(normalizer), (1,)) 470 self.matrix_g_cov = matrix_g 471 else: 472 dout = self.reduce_mean(dout, 0) 473 dout_shape = self.shape(dout) 474 dout = self.reshape(dout, (dout_shape[0], -1)) 475 dout_shape = self.shape(dout) 476 normalizer = dout_shape[1] 477 dout = self.cast(dout, mstype.float32) 478 matrix_g = self.matmul(dout, dout) 479 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 480 self.g_normalizer = self.reshape(Tensor(normalizer), (1,)) 481 self.matrix_g_cov = matrix_g 482 return out 483 484 def construct(self, x): 485 if self.thor: 486 if self.is_ascend: 487 matrix_a = self.im2col(x) 488 matrix_a_shape = self.shape(matrix_a) 489 y = matrix_a_shape[3] 490 matrix_a = self.reshape(matrix_a, (-1, y)) 491 matrix_a_shape = self.shape(matrix_a) 492 normalizer = matrix_a_shape[0] 493 matrix_a = self.cube_matmul(matrix_a, matrix_a) 494 normalizer = self.cast(normalizer, mstype.float32) 495 matrix_a = self.mul(matrix_a, 1.0 / normalizer) 496 self.a_normalizer = self.reshape(Tensor(normalizer), (1,)) 497 self.matrix_a_cov = matrix_a 498 weight = self.cast(self.weight, mstype.float16) 499 output = self.conv2d(x, weight) 500 output = self.getG(output) 501 else: 502 matrix_a = self.img2col(x) 503 matrix_a_shape = self.shape(matrix_a) 504 matrix_a = self.reshape(matrix_a, (matrix_a_shape[0] * matrix_a_shape[1] * matrix_a_shape[2], 505 matrix_a_shape[3], -1)) 506 matrix_a = self.reduce_mean(matrix_a, 1) 507 matrix_a_shape = self.shape(matrix_a) 508 normalizer = matrix_a_shape[1] 509 matrix_a = self.cast(matrix_a, mstype.float32) 510 matrix_a = self.matmul(matrix_a, matrix_a) 511 matrix_a = self.mul(matrix_a, 1.0 / normalizer) 512 self.a_normalizer = self.reshape(Tensor(normalizer), (1,)) 513 self.matrix_a_cov = matrix_a 514 output = self.conv2d(x, self.weight) 515 output = self.getG(output) 516 else: 517 if self.is_ascend: 518 weight = self.cast(self.weight, mstype.float16) 519 output = self.conv2d(x, weight) 520 else: 521 output = self.conv2d(x, self.weight) 522 if self.has_bias: 523 if self.is_ascend: 524 bias = self.cast(self.bias, mstype.float16) 525 output = self.bias_add(output, bias) 526 else: 527 output = self.bias_add(output, self.bias) 528 return output 529 530 def extend_repr(self): 531 s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \ 532 'pad_mode={}, padding={}, dilation={}, group={}, has_bias={}, ' \ 533 'bias_init={}'.format(self.in_channels, self.out_channels, self.kernel_size, 534 self.stride, self.pad_mode, self.padding, self.dilation, 535 self.group, self.has_bias, self.bias_init) 536 return s 537 538 539class EmbeddingThor(Cell): 540 r""" 541 A simple lookup table that stores embeddings of a fixed dictionary and size 542 and saving the information needed for THOR. 543 544 This module is often used to store word embeddings and retrieve them using 545 indices. The input to the module is a list of indices, and the output is 546 the corresponding word embeddings. And saves the information A and G in the dense connected layer 547 needed for THOR. 548 549 Note: 550 When 'use_one_hot' is set to True, the type of the input `x` must be mindspore.int32. 551 552 Args: 553 vocab_size (int): The size of the dictionary of embeddings. 554 embedding_size (int): The size of each embedding vector. 555 use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` . 556 embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializes the embedding_table. 557 Refer to class `initializer` for the values of string when a string is specified. Default: ``'normal'`` . 558 dtype (:class:`mindspore.dtype`): Data type of input `x`. Default: ``mindspore.float32`` . 559 padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index 560 will be initialized to zero. Default: ``None`` . The feature is inactivated. 561 Inputs: 562 - **x** (Tensor) - Tensor of input shape :math:`(\text{batch_size}, \text{x_length})`. The elements of 563 the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will 564 be zero. 565 566 Outputs: 567 Tensor of output shape :math:`(\text{batch_size}, \text{x_length}, \text{embedding_size})`. 568 569 Supported Platforms: 570 ``Ascend`` ``GPU`` 571 572 Examples: 573 >>> import mindspore as ms 574 >>> import numpy as np 575 >>> net = ms.nn.EmbeddingThor(20000, 768, True) 576 >>> x = ms.Tensor(np.ones([8, 128]), ms.int32) 577 >>> 578 >>> # Maps the input word IDs to word embedding. 579 >>> output = net(x) 580 >>> output.shape 581 (8, 128, 768) 582 """ 583 584 def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', 585 dtype=mstype.float32, padding_idx=None): 586 """Initialize EmbeddingThor.""" 587 super(EmbeddingThor, self).__init__() 588 self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) 589 self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) 590 Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name) 591 Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) 592 self.use_one_hot = use_one_hot 593 self.dtype = dtype 594 self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size]) 595 self.padding_idx = padding_idx 596 if padding_idx is not None: 597 self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Validator.INC_BOTH, 598 "padding_idx", self.cls_name) 599 self.init_tensor[self.padding_idx] = 0 600 self.embedding_table = Parameter(self.init_tensor, name='embedding_table') 601 self.expand = P.ExpandDims() 602 self.reshape_flat = P.Reshape() 603 self.shp_flat = (-1,) 604 self.gather = P.Gather() 605 self.one_hot = P.OneHot() 606 self.on_value = Tensor(1.0, self.dtype) 607 self.off_value = Tensor(0.0, self.dtype) 608 self.array_mul = P.MatMul() 609 self.reshape = P.Reshape() 610 self.get_shp = P.Shape() 611 self.thor = True 612 self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), 613 name='matrix_a', requires_grad=False) 614 self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)), 615 name="matrix_g", requires_grad=False) 616 self.reduce_sum = P.ReduceSum(keep_dims=False) 617 self.getG = P.InsertGradientOf(self.save_gradient) 618 self.cast = P.Cast() 619 if context.get_context("device_target") == "Ascend": 620 self.cube_matmul = P.CusMatMulCube(transpose_a=True) 621 else: 622 self.cube_matmul = P.MatMul(transpose_a=True) 623 self.mul = P.Mul() 624 625 def save_gradient(self, dout): 626 """ 627 this function only for thor optimizer 628 save_gradient 629 """ 630 out = dout 631 shape = self.get_shp(dout) 632 normalizer = self.cast(shape[0], mstype.float32) 633 matrix_g = self.cube_matmul(dout, dout) 634 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 635 self.matrix_g = matrix_g 636 return out 637 638 def construct(self, ids): 639 extended_ids = self.expand(ids, -1) 640 out_shape = self.get_shp(ids) + (self.embedding_size,) 641 flat_ids = self.reshape_flat(extended_ids, self.shp_flat) 642 643 if self.use_one_hot: 644 one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) 645 output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) 646 else: 647 if self.thor: 648 one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) 649 matrix_a = self.reduce_sum(one_hot_ids, 0) 650 self.matrix_a = matrix_a 651 output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) 652 output_for_reshape = self.getG(output_for_reshape) 653 else: 654 output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) 655 656 output = self.reshape(output_for_reshape, out_shape) 657 # We use Depend to make 'self.matrix_g' as primal graph's weight parameter, 658 # for it's used in 'save_gradient' gradient procedure. 659 return F.depend(output, self.matrix_g) 660 661 def extend_repr(self): 662 s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format( 663 self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx) 664 return s 665 666 667@constexpr 668def _make_axis_range(start, end): 669 axis = tuple(range(start, end)) 670 return axis 671 672 673class EmbeddingLookupThor(Cell): 674 r""" 675 Returns a slice of the input tensor based on the specified indices 676 and saving the information needed for THOR. 677 678 This module has the same function as EmbeddingLookup, but additionally saves the information A and G in the 679 embeddinglookup layer needed for THOR. 680 681 682 Args: 683 vocab_size (int): The size of the dictionary of embeddings. 684 embedding_size (int): The size of each embedding vector. 685 param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table. 686 Refer to class `initializer` for the values of string when a string is specified. 687 Default: ``'normal'`` . 688 target (str): Specifies the target where the op is executed. The value must in 689 [ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` . 690 slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through 691 nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE. 692 manual_shapes (tuple): The accompaniment array in field slice mode. 693 max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None. 694 Default: ``None`` . 695 sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be ``true`` . 696 Default: ``True`` . 697 vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in 698 'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size. 699 In addition, it should be noted that it will cost the 'DEVICE' memory, so suggests setting a reasonable 700 value to avoid insufficient memory. 701 702 Inputs: 703 - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. 704 705 Outputs: 706 Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. 707 708 Raises: 709 ValueError: If `target` is neither 'CPU' nor 'DEVICE'. 710 ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or 711 'table_row_slice' or 'table_column_slice'. 712 ValueError: If `sparse` is False and `target` is 'CPU'. 713 ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None. 714 TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int. 715 TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple. 716 ValueError: If `vocab_size` or `embedding_size` is less than 1. 717 ValueError: If `vocab_cache_size` is less than 0. 718 719 720 Supported Platforms: 721 ``Ascend`` 722 723 Examples: 724 >>> import mindspore as ms 725 >>> import numpy as np 726 >>> input_indices = ms.Tensor(np.array([[1, 0], [3, 2]]), ms.int32) 727 >>> result = ms.nn.EmbeddingLookup(4,2)(input_indices) 728 >>> print(result.shape) 729 (2, 2, 2) 730 """ 731 BATCH_SLICE = "batch_slice" 732 FIELD_SLICE = "field_slice" 733 TABLE_ROW_SLICE = "table_row_slice" 734 TABLE_COLUMN_SLICE = "table_column_slice" 735 736 def __init__(self, vocab_size, embedding_size, param_init='normal', 737 target='CPU', slice_mode='batch_slice', manual_shapes=None, 738 max_norm=None, sparse=True, vocab_cache_size=0): 739 super(EmbeddingLookupThor, self).__init__() 740 Validator.check_value_type('sparse', sparse, [bool], self.cls_name) 741 self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size', self.cls_name) 742 self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size', self.cls_name) 743 self.target = target 744 self.sparse = sparse 745 self.cache_enable = self.vocab_cache_size > 0 746 self.forward_unique = False 747 self.dtype = mstype.float16 748 if target not in ('CPU', 'DEVICE'): 749 raise ValueError(f"For '{self.cls_name}', the 'target' must be one of values in ('CPU', 'DEVICE'), " 750 f"but got {target}.") 751 if not sparse and target == 'CPU': 752 raise ValueError(f"For '{self.cls_name}', embedding_lookup must be sparse when 'target' is CPU, but got " 753 f"'sparse': {sparse}, 'target': {target}.") 754 if sparse: 755 self.gatherv2 = P.SparseGatherV2() 756 else: 757 self.gatherv2 = P.Gather() 758 self.embeddinglookup = P.EmbeddingLookup().set_device('CPU') 759 enable_ps = _get_ps_context("enable_ps") 760 if enable_ps: 761 self._process_vocab_cache(slice_mode) 762 self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name) 763 self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size], 764 mstype.float16), name='embedding_table') 765 parallel_mode = _get_parallel_mode() 766 is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) 767 self.gather_revert = P.Gather() 768 self.reshape_first = P.Reshape() 769 self.reshape = P.Reshape() 770 self.unique = P.Unique() 771 self.shape = P.Shape() 772 if is_auto_parallel: 773 self.unique = P.Unique().shard(((1,),)) 774 if self.cache_enable and enable_ps: 775 self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size) 776 if is_auto_parallel: 777 self.unique.add_prim_attr('cache_enable', True) 778 indices_shape_size = 2 779 if slice_mode == "field_slice" and is_auto_parallel: 780 if not manual_shapes: 781 raise ValueError(f"For '{self.cls_name}', the 'manual_shapes' should not be none " 782 f"when 'slice_mode' is 'field_slice'.") 783 if not isinstance(manual_shapes, tuple): 784 raise TypeError(f"For '{self.cls_name}', the type of 'manual_shapes' must be tuple(int), but got " 785 f"type {type(manual_shapes).__name__}.") 786 for dim in manual_shapes: 787 Validator.check_positive_int(dim, 'manual shape dim', self.cls_name) 788 self.gatherv2.add_prim_attr("manual_split", manual_shapes) 789 self.embeddinglookup.add_prim_attr("manual_split", manual_shapes) 790 self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) 791 self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) 792 elif slice_mode == "table_row_slice" and is_auto_parallel: 793 full_batch = _get_full_batch() 794 if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse): 795 indices_shape_size = 1 796 self.gather_revert.shard(((1, 1), (get_group_size(),))) 797 self.forward_unique = True 798 indices_strategy = (1,) * indices_shape_size 799 self.gatherv2.shard(((get_group_size(), 1), indices_strategy)) 800 self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy)) 801 elif slice_mode == "table_column_slice" and is_auto_parallel: 802 if target == 'DEVICE': 803 indices_shape_size = 1 804 self.gather_revert.shard(((1, get_group_size()), (1,))) 805 self.forward_unique = True 806 indices_strategy = (1,) * indices_shape_size 807 self.gatherv2.shard(((1, get_group_size()), indices_strategy)) 808 self.embeddinglookup.shard(((1, get_group_size()), indices_strategy)) 809 elif slice_mode == "batch_slice" and is_auto_parallel: 810 indices_strategy = [get_group_size()] 811 indices_strategy.extend([1] * (indices_shape_size - 1)) 812 indices_strategy = tuple(indices_strategy) 813 self.gatherv2.shard(((1, 1), indices_strategy)) 814 self.embeddinglookup.shard(((1, 1), indices_strategy)) 815 else: 816 if is_auto_parallel: 817 raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be one of values in " 818 f"['field_slice', 'table_row_slice', 'table_column_slice', 'batch_slice'], " 819 f"but got 'slice_mode': {slice_mode}") 820 if self.cache_enable and not enable_ps: 821 if parallel_mode != ParallelMode.STAND_ALONE: 822 raise ValueError(f"For '{self.cls_name}', the 'parallel_mode' must be equal to " 823 f"'ParallelMode.STAND_ALONE', but got {parallel_mode}.") 824 self._set_cache_enable() 825 self.embedding_table.unique = self.forward_unique 826 self.max_norm = max_norm 827 if self.max_norm is not None: 828 self.max_norm = Validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) 829 self.max_norm = Tensor(self.max_norm, dtype=mstype.float16) 830 831 self.thor = True 832 self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), 833 name='matrix_a', requires_grad=False) 834 self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)), 835 name="matrix_g", requires_grad=False) 836 self.reduce_sum = P.ReduceSum(keep_dims=False) 837 self.getG = P.InsertGradientOf(self.save_gradient) 838 self.cast = P.Cast() 839 self.cube_matmul = P.MatMul(transpose_a=True) 840 self.mul = P.Mul() 841 self.on_value = Tensor(1.0, self.dtype) 842 self.off_value = Tensor(0.0, self.dtype) 843 self.one_hot = P.OneHot() 844 845 846 def save_gradient(self, dout): 847 """ 848 this function only for thor optimizer 849 save_gradient 850 """ 851 out = dout 852 shape = self.shape(dout) 853 normalizer = self.cast(shape[0], mstype.float16) 854 dout = self.reshape(dout, (-1, self.embedding_size)) 855 matrix_g = self.cube_matmul(dout, dout) 856 matrix_g = self.mul(matrix_g, 1.0 / normalizer) 857 matrix_g = self.cast(matrix_g, mstype.float16) 858 self.matrix_g = matrix_g 859 return out 860 861 def _set_cache_enable(self): 862 """EmbeddingLookup cache check for not ps env, which is only support 'ascend'.""" 863 if self.target != 'DEVICE': 864 raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid " 865 f"only when 'target' is 'DEVICE', but got 'target': {self.target}.") 866 if not self.sparse: 867 raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid " 868 f"only when 'sparse' is true, but got 'sparse': {self.sparse}.") 869 if context.get_context("device_target") != 'Ascend': 870 raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid " 871 f"only when 'device_target' is 'Ascend', but got {context.get_context('device_target')}.") 872 873 logger.info("EmbeddingLookup cache enable takes effect.") 874 self.forward_unique = True 875 self.unique = P.Unique().set_device('CPU') 876 self.unique.add_prim_attr('cache_enable', True) 877 self.embedding_table.cache_enable = self.cache_enable 878 self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size) 879 self.reshape_first = P.Reshape().set_device('CPU') 880 881 def _process_vocab_cache(self, slice_mode): 882 """PS embeddingLookup cache check and process.""" 883 self.cache_enable = False 884 if self.vocab_cache_size > 0: 885 if self.target == 'CPU': 886 logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " 887 "current target is CPU, so it will be ignored.") 888 return 889 enable_ps = _get_ps_context("enable_ps") 890 if not enable_ps: 891 logger.warning( 892 "The configuration of 'vocab_cache_size' is valid only in parameter server trainning " 893 "mode, current mode is not parameter server trainning mode, so it will be ignored.") 894 return 895 parallel_mode = _get_parallel_mode() 896 is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) 897 if is_auto_parallel: 898 rank_size = get_group_size() 899 rank_id = get_rank() 900 full_batch = _get_full_batch() 901 if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"): 902 raise ValueError(f"For '{self.cls_name}', the embeddingLookup cache of parameter server parallel " 903 f"only be used in 'full_batch' and 'table_row_slice' parallel strategy, but got " 904 f"'full_batch': {full_batch}, 'slice_mode': {slice_mode}.") 905 self.vocab_cache_size = self.vocab_cache_size * rank_size 906 _set_rank_id(rank_id) 907 self.cache_enable = True 908 if _is_role_worker(): 909 self.vocab_size = self.vocab_cache_size 910 911 def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size): 912 """PS embeddingLookup cache enable set.""" 913 self.embedding_table.cache_enable = True 914 self.embedding_table.is_param_ps = True 915 _set_cache_enable(True) 916 if self.sparse: 917 self.forward_unique = True 918 if _is_role_worker(): 919 _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) 920 921 def construct(self, indices): 922 if self.target == "CPU": 923 out = self.embeddinglookup(self.embedding_table, indices, 0) 924 else: 925 if self.thor: 926 if self.forward_unique: 927 shp = self.shape(indices) + (self.embedding_size,) 928 indices_flatten = self.reshape_first(indices, (-1,)) 929 unique_id, unique_idx = self.unique(indices_flatten) 930 one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value) 931 matrix_a = self.reduce_sum(one_hot_ids, 0) 932 matrix_a = self.cast(matrix_a, mstype.float16) 933 self.matrix_a = matrix_a 934 weight_unique = self.gatherv2(self.embedding_table, unique_id, 0) 935 out = self.getG(weight_unique) 936 weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) 937 out = self.reshape(weight_flatten, shp) 938 939 else: 940 indices_flatten = self.reshape_first(indices, (-1,)) 941 one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value) 942 matrix_a = self.reduce_sum(one_hot_ids, 0) 943 matrix_a = self.cast(matrix_a, mstype.float16) 944 self.matrix_a = matrix_a 945 out = self.gatherv2(self.embedding_table, indices, 0) 946 out = self.getG(out) 947 else: 948 if self.forward_unique: 949 shp = self.shape(indices) + (self.embedding_size,) 950 indices_flatten = self.reshape_first(indices, (-1,)) 951 unique_id, unique_idx = self.unique(indices_flatten) 952 weight_unique = self.gatherv2(self.embedding_table, unique_id, 0) 953 weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) 954 out = self.reshape(weight_flatten, shp) 955 else: 956 out = self.gatherv2(self.embedding_table, indices, 0) 957 if self.max_norm is not None: 958 axis = _make_axis_range(F.rank(indices), F.rank(out)) 959 clip_by_norm = ClipByNorm(axis) 960 out = clip_by_norm(out, self.max_norm) 961 # We use Depend to make 'self.matrix_g' as primal graph's weight parameter, 962 # for it's used in 'save_gradient' gradient procedure. 963 return F.depend(out, self.matrix_g) 964