# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""thor_ops"""
16import math
17
18from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
19from mindspore.common import dtype as mstype
20from mindspore import _checkparam as validator
21from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
22
__all__ = ["CusBatchMatMul",
           "CusCholeskyTrsm",
           "CusFusedAbsMax1",
           "CusImg2Col",
           "CusMatMulCubeDenseLeft",
           "CusMatMulCubeFraczRightMul",
           "CusMatMulCube",
           "CusMatrixCombine",
           "CusTranspose02314",
           "CusMatMulCubeDenseRight",
           "CusMatMulCubeFraczLeftCast",
           "LoadIm2Col"
           ]


class CusBatchMatMul(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b` in batch.

    The rank of input tensors must be `3`.

    Inputs:
        - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`.
        - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, D, D)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
        >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
        >>> cus_batch_matmul = ops.CusBatchMatMul()
        >>> output = cus_batch_matmul(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusBatchMatMul"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
        # Importing the implementation module registers the TBE custom-op kernel
        # (the same pattern is used by the other Cus* primitives below).
        from mindspore.ops._op_impl._custom_op.batch_matmul_impl import cus_batch_matmul

    def infer_shape(self, data1_shape, data2_shape):
        return data1_shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return data1_dtype


class CusCholeskyTrsm(PrimitiveWithInfer):
    r"""
    Given :math:`L * L^T = A` and :math:`L^T * (L^T)^{-1} = I`, returns :math:`(L^T)^{-1}`.

    Only the result for the diagonal blocks (of dimension 128) of the input matrix is computed.
    The rank of input tensors must be `2`.

    Inputs:
        - **input_x** (Tensor) - The input matrix. The shape of the tensor is :math:`(N, N)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N // Split\_dim, Split\_dim, Split\_dim)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32)
        >>> cus_choleskytrsm = ops.CusCholeskyTrsm()
        >>> output = cus_choleskytrsm(input_x)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusCholeskyTrsm"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import cus_cholesky_trsm

    def infer_shape(self, data1_shape):
        ll = []
        m, _ = data1_shape
        if m >= 128:
            ll = [m // 128, 128, 128]
        else:
            ll = [1, 64, 64]
        return ll

    def infer_dtype(self, data1_dtype):
        return data1_dtype


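# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API and not used by the TBE kernel above): a NumPy reference of the math the
# CusCholeskyTrsm docstring describes, applied per diagonal 128 x 128 block as suggested
# by its infer_shape.
def _numpy_cholesky_trsm_sketch(a, split_dim=128):
    """For each diagonal block A_i, compute (L_i^T)^{-1} where L_i @ L_i.T == A_i."""
    import numpy as np  # local import; numpy is only needed for this sketch
    blocks = []
    for i in range(0, a.shape[0], split_dim):
        block = a[i:i + split_dim, i:i + split_dim]
        l = np.linalg.cholesky(block)      # lower-triangular factor, L @ L.T == block
        blocks.append(np.linalg.inv(l.T))  # (L^T)^{-1}
    return np.stack(blocks)                # shape (N // split_dim, split_dim, split_dim)

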
class CusFusedAbsMax1(PrimitiveWithInfer):
    """
    Computes the maximum absolute value of the input tensor.

    The rank of input tensors must be `4` or `2`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N0, M0, N1, M1)`
          or :math:`(32, 64)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(32, 64)` or :math:`(1,)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
        >>> cus_fused_abs_max1 = ops.CusFusedAbsMax1()
        >>> output = cus_fused_abs_max1(input_x)
    """

    @prim_attr_register
    def __init__(self, origin_shape=(-1, -1)):
        """Initialize CusFusedAbsMax1"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.origin_shape = origin_shape
        from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import cus_fused_abs_max1

    def infer_shape(self, data1_shape):
        ll = []
        if len(data1_shape) == 2:
            ll = [1]
        else:
            ll = [32, 64]
        return ll

    def infer_dtype(self, data1_dtype):
        return data1_dtype


class CusImg2Col(PrimitiveWithInfer):
    """
    Applies img2col to the feature map; the result is reorganized in NC1HWC0 format.

    Args:
        - **strides** (listInt) - The strides of the sliding window.
        - **ksizes** (listInt) - The kernel size of the sliding window.
    Inputs:
        - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`.
    Examples:
        >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16)
        >>> cusimg2col = ops.CusImg2Col(ksizes=(1, 7, 7, 1), strides=(1, 2, 2, 1))
        >>> output = cusimg2col(input_x)
    """

    @prim_attr_register
    def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"):
        """Initialize CusImg2Col"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.ksizes = ksizes
        self.strides = strides
        self.dilates = dilates
        self.mode = mode
        from mindspore.ops._op_impl._custom_op.img2col_impl import cus_img2col

    def infer_shape(self, data1_shape):
        bs, c, h, w = data1_shape
        _, stride_h, stride_w, _ = self.strides
        _, k_w, k_h, _ = self.ksizes
        c0 = 16
        c1 = c // 16
        if c1 == 0:
            c1 = 1
        shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0]
        return shape

    def infer_dtype(self, data1_dtype):
        return data1_dtype


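# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): a plain-Python mirror of CusImg2Col.infer_shape. For the docstring
# example, a (32, 3, 224, 224) input with ksizes=(1, 7, 7, 1) and strides=(1, 2, 2, 1)
# gives 32 * (224 // 2) * (224 // 2) = 401408 rows and 7 * 7 * 1 * 16 = 784 columns.
def _img2col_output_shape_sketch(x_shape, ksizes, strides):
    """Return [N * H_O * W_O, K_W * K_H * C1 * C0] for an NCHW input shape."""
    bs, c, h, w = x_shape
    _, stride_h, stride_w, _ = strides
    _, k_w, k_h, _ = ksizes
    c1 = max(c // 16, 1)  # channels are packed into C1 groups of C0 = 16
    return [bs * (h // stride_h) * (w // stride_w), k_w * k_h * c1 * 16]

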
class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b`.

    The rank of input_x1 must be `4`, the fractal format of the normal matrix.
    The rank of input_x2 must be `2`.

    Inputs:
        - **input_x1** (Tensor) - The first tensor to be multiplied.
          The shape of the tensor is :math:`(N0, M0, N1, M1)`.
        - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C)`.
    Examples:
        >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
        >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> matmulcubedenseleft = ops.CusMatMulCubeDenseLeft()
        >>> output = matmulcubedenseleft(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusMatMulCubeDenseLeft"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import cus_matmul_cube_dense_left

    def infer_shape(self, data1_shape, data2_shape):
        return data2_shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return mstype.float16


class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b` and multiplies the result by the scalar `c`.

    The rank of input_x1 tensors must be `2`.
    The rank of input_x2 tensors must be `4`.

    Inputs:
        - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
        - **input_x2** (Tensor) - The second tensor to be multiplied.
          The shape of the tensor is :math:`(C1, M1, C0, M0)`.
        - **input_x3** (Tensor) - The third tensor, the scalar multiplier. The shape of the tensor is :math:`(1,)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, M)`.
    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
        >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16)
        >>> cusmatmulfraczrightmul = ops.CusMatMulCubeFraczRightMul()
        >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusMatMulCubeFraczRightMul"""
        self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import cus_matmul_cube_fraczrightmul

    def infer_shape(self, data1_shape, data2_shape, data3_shape):
        return data1_shape

    def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
        return mstype.float32


class CusMatMulCube(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b`.

    The rank of input tensors must be `2`.

    Args:
        transpose_a (bool): If true, `a` is transposed before multiplication. Default: ``False``.
        transpose_b (bool): If true, `b` is transposed before multiplication. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If
          `transpose_a` is True, its shape must be :math:`(N, C)` after transposing.
        - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If
          `transpose_b` is True, its shape must be :math:`(C, M)` after transposing.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, M)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> cusmatmulcube = ops.CusMatMulCube()
        >>> output = cusmatmulcube(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self, transpose_a=False, transpose_b=False):
        """Initialize CusMatMulCube"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
        self.transpose_a = transpose_a
        self.transpose_b = transpose_b
        from mindspore.ops._op_impl._custom_op.matmul_cube_impl import cus_matmul_cube

    def infer_shape(self, data1_shape, data2_shape):
        if self.transpose_a:
            _, m = data1_shape
        else:
            m, _ = data1_shape
        if self.transpose_b:
            n, _ = data2_shape
        else:
            _, n = data2_shape
        shape = [m, n]
        return shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return mstype.float32


class CusMatrixCombine(PrimitiveWithInfer):
    """
    Moves the batch of matrices onto the diagonal blocks of the result matrix.

    The rank of input tensors must be `3`.

    Inputs:
        - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N * D, N * D)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
        >>> cusmatrixcombine = ops.CusMatrixCombine()
        >>> output = cusmatrixcombine(input_x)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusMatrixCombine"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.matrix_combine_impl import cus_matrix_combine

    def infer_shape(self, data_shape):
        a, b, c = data_shape
        shape = [a * b, a * c]
        return shape

    def infer_dtype(self, data_dtype):
        return data_dtype


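# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): a NumPy reference of the block-diagonal layout that CusMatrixCombine
# describes, matching its infer_shape of (N, D, D) -> (N * D, N * D).
def _numpy_matrix_combine_sketch(x):
    """Place each (D, D) batch matrix of `x` on the diagonal of an (N * D, N * D) zero matrix."""
    import numpy as np  # local import; numpy is only needed for this sketch
    n, d, _ = x.shape
    out = np.zeros((n * d, n * d), dtype=x.dtype)
    for i in range(n):
        out[i * d:(i + 1) * d, i * d:(i + 1) * d] = x[i]
    return out

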
class CusTranspose02314(PrimitiveWithInfer):
    """
    Permutes the input tensor with perm (0, 2, 3, 1, 4) and merges the result into two dimensions.

    The rank of input tensors must be `5` with format NC1HWC0.

    Inputs:
        - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N * H * W, C1 * C0)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16)
        >>> custranspose02314 = ops.CusTranspose02314()
        >>> output = custranspose02314(input_x)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusTranspose02314"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.transpose02314_impl import cus_transpose02314

    def get_bprop(self):
        """Get backprop for CusTranspose02314."""

        def bprop(x, out, dout):
            # No gradient is propagated through this op; the backward pass returns zeros.
            return (C.zeros_like(x),)

        return bprop

    def infer_shape(self, data1_shape):
        n, c, h, w = data1_shape
        c0 = 16
        c1 = c // 16
        shape = (n * h * w, c1 * c0)
        return shape

    def infer_dtype(self, data1_dtype):
        return data1_dtype


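# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): a NumPy reference of the (0, 2, 3, 1, 4) permutation described above,
# starting from the device-format view (N, C1, H, W, C0) and merging the result into the
# (N * H * W, C1 * C0) shape reported by CusTranspose02314.infer_shape.
def _numpy_transpose02314_sketch(x):
    """Permute an (N, C1, H, W, C0) array with perm (0, 2, 3, 1, 4) and flatten to 2-D."""
    import numpy as np  # local import; numpy is only needed for this sketch
    n, c1, h, w, c0 = x.shape
    return np.ascontiguousarray(x.transpose(0, 2, 3, 1, 4)).reshape(n * h * w, c1 * c0)

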
class CusMatMulCubeDenseRight(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b`.

    The rank of input_x1 tensor must be `2`.
    The rank of input_x2 tensor must be `4`.

    Inputs:
        - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
        - **input_y** (Tensor) - The second tensor to be multiplied.
          The shape of the tensor is :math:`(C1, M1, M0, C0)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, M)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
        >>> cusmatmulcubedenseright = ops.CusMatMulCubeDenseRight()
        >>> output = cusmatmulcubedenseright(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusMatMulCubeDenseRight"""
        self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import cus_matmul_cube_dense_right

    def infer_shape(self, data1_shape, data2_shape, data3_shape):
        return data1_shape

    def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
        return mstype.float32


class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b`.

    The rank of input_x1 tensor must be `4`.
    The rank of input_x2 tensors must be `2`.

    Inputs:
        - **input_x1** (Tensor) - The first tensor to be multiplied.
          The shape of the tensor is :math:`(C1, N1, N0, C0)`.
        - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, M)`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
        >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> cusmatmulcubefraczleftcast = ops.CusMatMulCubeFraczLeftCast()
        >>> output = cusmatmulcubefraczleftcast(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CusMatMulCubeFraczLeftCast"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
        from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import cus_matmul_cube_fraczleftcast

    def infer_shape(self, data1_shape, data2_shape):
        return data2_shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return mstype.float16


class ThorIm2Col(PrimitiveWithInfer):
    """
    Extracts image patches from the image.

    The rank of input_x1 must be `4`, data_format is "NCHW".

    Inputs:
        - **input_x1** (Tensor) - The feature map.
          The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor.
    Examples:
        >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
        >>> img2col = ops.ThorIm2Col(kernel_size=7, pad_mode="pad", pad=3, stride=2)
        >>> output = img2col(input_x)
    """

    @prim_attr_register
    def __init__(self,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1):
        """Initialize ThorIm2Col"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        if self.pad_mode == 'pad':
            validator.check_non_negative_int(self.pad, 'pad', self.name)
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, x_shape):
        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
        kernel_size_h = self.kernel_size[0]
        kernel_size_w = self.kernel_size[1]
        stride_h = self.stride[2]
        stride_w = self.stride[3]
        dilation_h = self.dilation[2]
        dilation_w = self.dilation[3]
        if self.pad_mode == "valid":
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.pad_mode == "same":
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        elif self.pad_mode == 'pad':
            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h
            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)
        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
        batch_size = x_shape[0]
        channel = x_shape[1]
        k_h = kernel_size_h
        k_w = kernel_size_w
        out_shape = [channel, k_h, k_w, batch_size, h_out, w_out]
        return out_shape

    def infer_dtype(self, x_dtype):
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
        return x_dtype


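# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): the "same" padding arithmetic used by ThorIm2Col.infer_shape (and by
# NewIm2Col.infer_shape below) for one spatial dimension. For size 224, kernel 7,
# stride 2 it returns (112, 2, 3): 112 output positions with 2 + 3 rows of padding.
def _same_padding_sketch(size, kernel, stride, dilation=1):
    """Return (output_size, pad_before, pad_after) for 'same' padding along one dimension."""
    out = math.ceil(size / stride)
    pad_needed = max(0, (out - 1) * stride + dilation * (kernel - 1) + 1 - size)
    pad_before = pad_needed // 2
    return out, pad_before, pad_needed - pad_before

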
class NewIm2Col(PrimitiveWithInfer):
    """
    Extracts image patches from the image by using TBE.

    The rank of input_x1 must be `4`, data_format is "NCHW".

    Inputs:
        - **input_x1** (Tensor) - The feature map.
          The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor. The shape of the tensor is :math:`(N, H_O, W_O, C * K_H * K_W)`.

    Examples:
        >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
        >>> im2col = ops.NewIm2Col(ksizes=(7, 7), strides=2)
        >>> output = im2col(input_x)
    """

    @prim_attr_register
    def __init__(self,
                 ksizes,
                 padding_mode="SAME",
                 strides=1,
                 dilations=1,
                 pads=0):
        """Initialize NewIm2Col"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.ksizes = ksizes
        self.strides = strides
        self.add_prim_attr('ksizes', self.ksizes)
        self.add_prim_attr('strides', self.strides)
        self.dilations = dilations
        self.add_prim_attr('dilations', self.dilations)
        self.padding_mode = validator.check_string(padding_mode, ['VALID', 'SAME'], 'padding_mode', self.name)
        self.add_prim_attr('data_format', "NCHW")
        self.pads = pads

    def infer_shape(self, x_shape):
        """infer shape"""
        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
        kernel_size_h = self.ksizes[0]
        kernel_size_w = self.ksizes[1]
        stride_h = self.strides
        stride_w = self.strides
        dilation_h = self.dilations
        dilation_w = self.dilations
        if self.padding_mode == "VALID":
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.padding_mode == "SAME":
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
        batch_size = x_shape[0]
        channel = x_shape[1]
        k_h = kernel_size_h
        k_w = kernel_size_w
        out_shape = [batch_size, h_out, w_out, channel * k_h * k_w]
        return out_shape

    def infer_dtype(self, x_dtype):
        """infer dtype"""
        valid_dtypes = [mstype.float16, mstype.int8]
        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
        return x_dtype


class LoadIm2Col(PrimitiveWithInfer):
    """
    Extracts image patches from the image.

    The rank of input_x1 must be `4`, data_format is "NCHW".
    Only supports the case where C is divisible by 16.

    Inputs:
        - **input_x1** (Tensor) - The feature map.
          The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor.
    Examples:
        >>> input_x = Tensor(np.random.rand(32, 16, 224, 224).astype(np.float16))
        >>> img2col = ops.LoadIm2Col(ksizes=(7, 7), strides=(2, 2))
        >>> output = img2col(input_x)
    """

    @prim_attr_register
    def __init__(self,
                 ksizes,
                 strides,
                 pad_mode="same",
                 dilates=(1, 1, 1, 1)):
        """Initialize LoadIm2Col"""

        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.ksizes = ksizes
        self.strides = strides
        self.pad_mode = validator.check_string(pad_mode, ['same'], 'pad_mode', self.name)
        self.dilation = dilates

    def infer_shape(self, data1_shape):
        bs, c, h, w = data1_shape
        stride_h, stride_w = self.strides
        k_w, k_h = self.ksizes
        h_out = math.ceil(h / stride_h)
        w_out = math.ceil(w / stride_w)
        m = h_out * w_out
        if m % 16 != 0:
            shape = [(bs * m) // 16, (c * k_h * k_w) // 16, 16, 16]
        else:
            shape = [bs, m // 16, (c * k_h * k_w) // 16, 16, 16]
        return shape

    def infer_dtype(self, data1_dtype):
        return data1_dtype


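# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): a plain-Python mirror of LoadIm2Col.infer_shape with "same" padding.
# For the docstring example, a (32, 16, 224, 224) input with ksizes=(7, 7) and
# strides=(2, 2) gives m = 112 * 112 = 12544 patches, so the inferred output shape is
# [32, 12544 // 16, (16 * 7 * 7) // 16, 16, 16] = [32, 784, 49, 16, 16].
def _load_im2col_output_shape_sketch(x_shape, ksizes, strides):
    """Return the 16-aligned tiled output shape LoadIm2Col reports for an NCHW input."""
    bs, c, h, w = x_shape
    stride_h, stride_w = strides
    k_w, k_h = ksizes
    m = math.ceil(h / stride_h) * math.ceil(w / stride_w)
    if m % 16 != 0:
        return [(bs * m) // 16, (c * k_h * k_w) // 16, 16, 16]
    return [bs, m // 16, (c * k_h * k_w) // 16, 16, 16]

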
class UpdateThorGradient(PrimitiveWithInfer):
    """
    Updates the THOR gradient with the approximate Fisher information matrix (for the GPU backend).

    The rank of input_x1 must be `3`, which indicates the A matrix.
    The rank of input_x2 must be `2`, which indicates the 1st-order gradient.
    The rank of input_x3 must be `4`, which indicates the G matrix.

    Inputs:
        - **input_x1** (Tensor) - The first input is the diag part of the cov matrix of the feature map.
          Supported dtype [float32].
        - **input_x2** (Tensor) - The second input is the corresponding 1st-order grad. Supported dtype [float32].
        - **input_x3** (Tensor) - The third input is the diag part of the cov matrix of dout.
          Supported dtype [float32].

    Outputs:
        Tensor, the shape is the same as the shape of input_x2; it will be used to update the weights.

    Examples:
        >>> input_x1 = Tensor(np.random.rand(16, 128, 128).astype(np.float32))
        >>> input_x2 = Tensor(np.random.rand(2048, 1024).astype(np.float32))
        >>> temp_x3 = np.random.rand(8, 128, 128).astype(np.float32)
        >>> input_x3 = np.zeros((16, 8, 128, 128)).astype(np.float32)
        >>> for i in range(16):
        ...     input_x3[i, :, :, :] = temp_x3
        >>> input_x3 = Tensor(input_x3)
        >>> update_thor_gradient = ops.UpdateThorGradient(split_dim=128)
        >>> output = update_thor_gradient(input_x1, input_x2, input_x3)
    """

    @prim_attr_register
    def __init__(self, split_dim=1):
        """Initialize UpdateThorGradient"""
        self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
        self.split_dim = split_dim
        self.add_prim_attr('split_dim', self.split_dim)

    def infer_shape(self, x1_shape, x2_shape, x3_shape):
        return x2_shape

    def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
        validator.check_tensors_dtypes_same_and_valid(
            {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
            [mstype.float32], self.name)
        return x2_dtype


class _Cholesky(PrimitiveWithInfer):
    """
    Inner API for the _Cholesky base class.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=True, split_dim=0):
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.lower = validator.check_value_type("lower", lower, [bool], self.name)
        self.add_prim_attr('lower', self.lower)
        self.clean = validator.check_value_type("clean", clean, [bool], self.name)
        self.add_prim_attr('clean', self.clean)
        self.split_dim = split_dim
        self.add_prim_attr('split_dim', self.split_dim)

    def infer_shape(self, x1_shape):
        if self.split_dim != 0:
            height = x1_shape[0]
            width = x1_shape[1]
            if height <= self.split_dim:
                out_shape = [1, height, width]
            else:
                batch = height // self.split_dim
                if height != batch * self.split_dim:
                    batch += 1
                out_shape = [batch, self.split_dim, self.split_dim]
        else:
            out_shape = x1_shape
        return out_shape

    def infer_dtype(self, x1_dtype):
        validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32, mstype.float64], self.name)
        return x1_dtype


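# Illustrative sketch only (an assumption for documentation purposes, not part of the
# MindSpore API): a plain-Python mirror of _Cholesky.infer_shape. With split_dim=128,
# a (256, 256) input is reported as [2, 128, 128] and a (300, 300) input as
# [3, 128, 128] (the partial trailing block is rounded up); with split_dim=0 the shape
# is unchanged.
def _cholesky_output_shape_sketch(x_shape, split_dim=0):
    """Return the (possibly block-split) output shape _Cholesky reports."""
    if split_dim == 0:
        return list(x_shape)
    height, width = x_shape[0], x_shape[1]
    if height <= split_dim:
        return [1, height, width]
    batch = height // split_dim
    if height != batch * split_dim:
        batch += 1  # a partial trailing block still occupies a full split_dim slot
    return [batch, split_dim, split_dim]

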
class Cholesky(_Cholesky):
    """
    Inner API for positive-definite matrix Cholesky decomposition (GPU backend).
    """


class CholeskyTrsm(_Cholesky):
    """
    Inner API for the ResNet-50 THOR GPU backend.
    """


class DetTriangle(PrimitiveWithInfer):
    """
    Calculates the determinant of triangular matrices.

    Args:
        fill_mode (int): The fill mode of the triangular matrix. Default: 0.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(..., N, N)`.

    Outputs:
        Tensor, the determinants, with the shape of `input_x` minus its last two dimensions and the same
        data type as `input_x`.

    Examples:
        >>> input_x = Tensor(np.ones(shape=[1, 3, 3]), mindspore.float32)
        >>> det_triangle = ops.DetTriangle(fill_mode=0)
        >>> output = det_triangle(input_x)
    """

    @prim_attr_register
    def __init__(self, fill_mode=0):
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.fill_mode = fill_mode
        self.add_prim_attr('fill_mode', self.fill_mode)

    def infer_shape(self, x1_shape):
        out_shape = x1_shape
        del out_shape[-2:]
        return out_shape

    def infer_dtype(self, x1_dtype):
        validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name)
        return x1_dtype


class ProdForceSeA(PrimitiveWithInfer):
    """
    ProdForceSeA.
    """

    @prim_attr_register
    def __init__(self, natoms=192):
        self.init_prim_io_names(inputs=['net_deriv_tensor', "in_deriv_tensor", "nlist_tensor"], outputs=['y'])
        self.natoms = natoms
        self.add_prim_attr('natoms', self.natoms)

    def infer_shape(self, x1_shape, x2_shape, x3_shape):
        out_shape = [x3_shape[0], x3_shape[1], 3]
        return out_shape

    def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
        return x1_dtype
