# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""math Operations."""
16from itertools import zip_longest
17from collections import deque
18import numpy as np
19from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
20from mindspore.common import dtype as mstype
21from mindspore._checkparam import Validator as validator
22from mindspore.ops.primitive import constexpr
23from mindspore.ops import functional as F
24from .. import operations as P
25
26# count_nonzero
27
28
@constexpr
def _check_validate_axis(axis, name):
    if isinstance(axis, (tuple, list)):
        for idx, item in enumerate(axis):
            validator.check_value_type("axis[%d]" % idx, item, [int], name)
    axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
    return axis


@constexpr
def _check_validate_keepdims(keep_dims, name):
    keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
    return keep_dims


def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
    r"""
    Count the number of nonzero elements across the given axis of the input tensor.

    Args:
        x (Tensor): Input tensor used to count non-zero elements.
          :math:`(N,*)` where :math:`*` means any number of additional dimensions.
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
                                                  Default: (), reduce all dimensions.
        keep_dims (bool): If true, keep the reduced dimensions with length 1.
                          If false, don't keep these dimensions. Default: False.
        dtype (Union[Number, mindspore.bool\_]): The data type of the output tensor. Only constant value is allowed.
                                             Default: mindspore.int32.

    Returns:
          Tensor, the number of nonzero elements. The data type is `dtype`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1: each value specified.
        >>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 2: all values are default.
        >>> nonzero_num = ops.count_nonzero(x=x)
        >>> print(nonzero_num)
        3
        >>> # case 3: axis value was specified as 0.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
        >>> print(nonzero_num)
        [1 2 0]
        >>> # case 4: axis value was specified as 1.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
        >>> print(nonzero_num)
        [1 2]
        >>> # case 5: keep_dims value was specified.
        >>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 6: keep_dims and axis values were specified.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
        >>> print(nonzero_num)
        [[1 2 0]]
    """

    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
    axis = _check_validate_axis(axis, "count_nonzero")
    keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
    const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

    not_equal = P.NotEqual()
    cast = P.Cast()
    reduce_sum = P.ReduceSum(keep_dims)
    nonzero_bool = not_equal(x, 0)
    # ReduceSum only supports float16 or float32 tensors.
    nonzero_val = cast(nonzero_bool, mstype.float32)
    nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)

    return nonzero_num

# tensor dot


@constexpr
def _int_to_tuple_conv(axes):
    """
    Converts ints to tuples in input axes, expected by most validation checks.
    """
    for x in [0, 1]:
        if isinstance(axes[x], int):
            axes[x] = (axes[x],)
    return axes


@constexpr
def _check_axes(axes, prim_name=None):
    """
    Check validity and type of the axes passed to the function.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
    if not isinstance(axes, int):
        axes = list(axes)  # to avoid immutability issues
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} dimension of 'axes' should be 2, but got 'axes': {axes}.")
        axes = _int_to_tuple_conv(axes)  # convert before length checks
        if len(axes[0]) != len(axes[1]):
            raise ValueError(f"{msg_prefix} first and second dim of 'axes' have to be the same size/length, "
                             f"but got 'axes': {axes}.")
        if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
            raise ValueError(f"{msg_prefix} 'axes' cannot have duplicate values, but got {axes}.")
    return axes


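# Hedged usage sketch (illustrative only, not part of the public API): the accepted
# 'axes' forms for tensor_dot are a bare int or a pair of axis groups of equal length:
#   _check_axes(2)                   # -> 2 (ints are resolved later by _axes_int_check)
#   _check_axes(((0, 1), (1, 2)))    # -> [(0, 1), (1, 2)]
#   _check_axes((1, 0))              # -> [(1,), (0,)] after int-to-tuple conversion
#   _check_axes(((0, 0), (1, 2)))    # raises ValueError (duplicate axis values)
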
@constexpr
def _typecheck_input(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
                        f"and x2_type: {x2_type}.")


@constexpr
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
    """
    Convert a single int axes into the 2-tuple form if required.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError(f"{msg_prefix} axes must be at least 0, but got {axes}.")
        if axes == 0:
            # outer product, no input validation required
            return [], []
        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
                             f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
        x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
        x2_ind = tuple(range(len(x2_shape))[:axes])
        axes = tuple((x1_ind, x2_ind))
        axes = _int_to_tuple_conv(axes)
    return axes


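# Hedged worked example (illustrative only): with 3-D inputs on both sides, an
# integer axes value selects the trailing axes of x1 and the leading axes of x2:
#   _axes_int_check((2, 3, 4), (4, 3, 2), 2)   # -> ((1, 2), (0, 1))
#   _axes_int_check((2, 3, 4), (4, 3, 2), 0)   # -> ([], []), i.e. outer product
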
@constexpr
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
    """
    Check that the axes have the correct length for the given inputs, that every
    axis value is within range of the corresponding shape, and that the axes
    select compatible dimensions in both inputs.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    shapes = [x1_shape, x2_shape]

    # axis length check
    for ix_input, x_axes in enumerate(axes):
        axes_len = len(x_axes)
        shape_dim_len = len(shapes[ix_input])
        if axes_len > shape_dim_len:
            raise ValueError(f"{msg_prefix} length of x_axes should be less than or equal to {shape_dim_len}, "
                             f"but got 'len(x_axes)': {axes_len}.")

    # axis values range check
    for ix_input, x_axes in enumerate(axes):
        comp_shape = shapes[ix_input]
        max_val = len(comp_shape) - 1
        min_val = -1 * len(comp_shape)
        for _, x_value in enumerate(x_axes):
            if not min_val <= x_value <= max_val:
                raise ValueError(f"{msg_prefix} value in axes should be in range: [{min_val}, {max_val}], "
                                 f"but got {x_value}.")

    # check axis values against input shapes - the axes are valid if the selected
    # dimensions match in either the given or the reversed order
    invalid_a = False
    invalid_b = False
    for i in range(len(axes[0])):  # sizes already validated
        if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
            invalid_a = True
        if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
            invalid_b = True
    if invalid_a and invalid_b:
        raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
                         f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
                         f"x1_shape: {x1_shape}, x2_shape: {x2_shape}, axes: {axes}.")


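# Hedged worked example (illustrative only): for x1 of shape (1, 2, 3) and x2 of
# shape (3, 1, 2), axes ((0, 1), (1, 2)) pass validation because
# x1_shape[0] == x2_shape[1] and x1_shape[1] == x2_shape[2]; swapping the second
# group to (2, 1) would still pass via the reversed match, while ((0, 1), (0, 1))
# would raise a ValueError.
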
@constexpr
def _calc_new_shape(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations;
    'position' refers to whether the tensor is first or second in the op.
    """
    contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims


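# Hedged worked example (illustrative only): for tensor_dot with
# x1_shape = (1, 2, 3), x2_shape = (3, 1, 2) and axes ((0, 1), (1, 2)):
#   _calc_new_shape((1, 2, 3), ((0, 1), (1, 2)), 0)  # -> ((3, 2), (2, 0, 1), (3,))
#   _calc_new_shape((3, 1, 2), ((0, 1), (1, 2)), 1)  # -> ((2, 3), (1, 2, 0), (3,))
# so x1 is transposed/reshaped to (3, 2), x2 to (2, 3), and the MatMul result of
# shape (3, 3) is reshaped back to the combined free dims (3,) + (3,).
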
def tensor_dot(x1, x2, axes, prim_name='tensor_dot'):
    """
    Computes a tensor contraction over arbitrary axes between tensors `x1` and `x2`.

    Contraction allows for the summation of products of elements of `x1` and `x2` on specified axes.
    The same number of axes must be specified for both `x1` and `x2`, and the values must be within
    the number of dims of both `x1` and `x2`.

    The selected dims in both inputs must also match.

    axes = 0 leads to an outer product.
    axes = 1 leads to normal matrix multiplication when both inputs are 2D.
    axes = 1 is the same as axes = ((1,),(0,)) where both `x1` and `x2` are 2D.
    axes = 2 is the same as axes = ((1,2),(0,1)) where both `x1` and `x2` are 3D.

    Inputs:
        - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32.
        - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32.
        - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
          tuple/list of length 2 with dimensions specified for `x1` and `x2` each. If a single value `N` is
          passed, the last N dims of `x1` and the first N dims of `x2` are automatically picked up in order
          as the axes for each respectively.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the free
        axes not contracted in both inputs.

    Raises:
        TypeError: If `x1` or `x2` is not a Tensor.
        TypeError: If `axes` is not one of the following: int, tuple, list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> output = ops.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    # input validity checks
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    axes = _check_axes(axes, prim_name)
    _typecheck_input(x1_type, x2_type, prim_name)
    # input compatibility check & axes format update
    axes = _axes_int_check(x1_shape, x2_shape, axes, prim_name)
    _validate_axes(x1_shape, x2_shape, axes, prim_name)
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
    output_shape = x1_ret + x2_ret  # combine free axes from both inputs
    # run tensor_dot op
    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
    x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
    mul_result = matmul_op(x1_reshaped, x2_reshaped)
    final_result = reshape_op(mul_result, output_shape)
    return final_result


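# Hedged equivalence sketch (illustrative only): for 2-D inputs, axes=1 reduces the
# contraction to an ordinary matrix product, e.g. for float32 tensors a of shape
# (m, k) and b of shape (k, n), ops.tensor_dot(a, b, 1) should match ops.matmul(a, b)
# up to dtype handling.
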
@constexpr
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")


@constexpr
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got "
                        f"x1_type: {x1_type} and x2_type: {x2_type}.")


@constexpr
def _get_transpose_shape(x2_shape):
    x2_shape_range = tuple(range(len(x2_shape)))
    x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
    return x2_shape_transpose


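# Hedged worked example (illustrative only): for a 4-D x2, the permutation returned
# by _get_transpose_shape moves the second-to-last axis to the front,
#   _get_transpose_shape((2, 5, 4, 8))   # -> (2, 0, 1, 3)
# so that dot() can flatten x2 to (x2_shape[-2], -1) for a single MatMul.
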
def dot(x1, x2, prim_name=None):
    """
    Computes a dot product between samples in two tensors.

    Inputs:
        - **x1** (Tensor) - First tensor in Dot op with datatype float16 or float32.
          The rank must be greater than or equal to 2.
        - **x2** (Tensor) - Second tensor in Dot op with datatype float16 or float32.
          The rank must be greater than or equal to 2.

    Outputs:
        Tensor, dot product of x1 and x2.

    Raises:
        TypeError: If the types of x1 and x2 are not the same.
        TypeError: If the dtype of x1 or x2 is not float16 or float32.
        ValueError: If the rank of x1 or x2 is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
        >>> print(output.shape)
        (2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]]
          [[3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]
           [3. 3.]]
          [[3. 3.]
           [3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 2, 2)
        >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]]
        >>> print(output.shape)
        (3, 2, 2, 1, 2)
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    _typecheck_input_dot(x1_type, x2_type, prim_name)
    _check_invalid_input(x1_shape, x2_shape, prim_name)

    if len(x1_shape) > 2 or len(x2_shape) > 2:
        x2_shape_transpose = _get_transpose_shape(x2_shape)
        x2_transpose = transpose_op(x2, x2_shape_transpose)
        x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
        x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
        mul_result = matmul_op(x1_reshape, x2_reshape)
        return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
    return matmul_op(x1, x2)


@constexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
    """
    Get the batch sizes from the two inputs.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
    return x1_shape[0], x2_shape[0]


@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} and "
                        f"x2_type: {x2_type}.")


@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
    """
    Check whether the axes are valid and cast axes from tuple to list.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError(f"{msg_prefix} axes cannot contain 0, but got axes: {axes}.")
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} length of axes must be equal to 2, but got {len(axes)}.")
        if isinstance(axes, tuple):
            axes = list(axes)
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Convert negative axes to their non-negative equivalents
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
        validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
                             f"and axes[1] must be less than or equal to len(x2_shape). "
                             f"But got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError(f"{msg_prefix} axes should not equal to 0, but got {axes}.")
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
            validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
        elif axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
                             f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
        else:
            axes = [axes, axes]
    else:
        raise ValueError(f"{msg_prefix} type of axes must be one of those: int, tuple(int), list(int), "
                         f"but got {type(axes).__name__}.")
    return axes


@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations;
    'position' refers to whether the tensor is first or second in the op.
    """
    axis = axes[position]
    contraction_axes = tuple([axis])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    transpose_perm = tuple([0]) + transpose_perm
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    new_shape = tuple([shape[0]]) + new_shape
    return new_shape, transpose_perm, free_dims


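# Hedged worked example (illustrative only), matching the batch_dot docstring case
# with x1_shape = (6, 2, 3, 4), x2_shape = (6, 5, 4, 8) and default axes resolved
# to [3, 2]:
#   _calc_new_shape_batchdot((6, 2, 3, 4), [3, 2], 0)  # -> ((6, 6, 4), (0, 1, 2, 3), (2, 3))
#   _calc_new_shape_batchdot((6, 5, 4, 8), [3, 2], 1)  # -> ((6, 4, 40), (0, 2, 1, 3), (5, 8))
# BatchMatMul then yields (6, 6, 40), which is reshaped to (6, 2, 3, 5, 8).
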
@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
    """
    Check whether the batch sizes of the two inputs are the same.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if x1_batch_size != x2_batch_size:
        raise ValueError(f"{msg_prefix} both inputs x1, x2 should have the same batch sizes, but got "
                         f"x1_batch_size: {x1_batch_size} and x2_batch_size: {x2_batch_size}.")


@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Compute the output shape for batch dot.
    """
    output_shape = tuple([batch_size]) + x1_ret + x2_ret
    return output_shape


def batch_dot(x1, x2, axes=None, prim_name=None):
    """
    Computes the batch dot product between samples in two tensors containing batch dims.

    .. math::
        output = x1[batch, :] * x2[batch, :]

    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32 and the rank of `x1` must be greater
          than or equal to 2.
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
          be the same as `x1` and the rank of `x2` must be greater than or equal to 2.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with the dimensions
          specified for `x1` and `x2` each. If a single value `N` is passed, the same axis `N` is used for both
          `x1` and `x2` (counted from the end of each shape if negative). Default: None.

    Outputs:
        Tensor, batch dot product of `x1` and `x2`. For example, the shape of the output
        for input `x1` of shape (batch, d1, axes, d2) and `x2` of shape (batch, d3, axes, d4) is
        (batch, d1, d2, d3, d4), where d1, d2, d3 and d4 are arbitrary sizes.

    Raises:
        TypeError: If the types of x1 and x2 are not the same.
        TypeError: If the dtype of x1 or x2 is not float32.
        ValueError: If the rank of x1 or x2 is less than 2.
        ValueError: If the batch dim is used in axes.
        ValueError: If len(axes) is less than 2.
        ValueError: If axes is not one of those: None, int, (int, int).
        ValueError: If an axis reversed from a negative int is too low for the dimensions of the input arrays.
        ValueError: If an axes value is too high for the dimensions of the input arrays.
        ValueError: If the batch sizes of x1 and x2 are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
        >>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (1, 2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]]
        >>> print(output.shape)
        (2, 3)
        >>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (6, 2, 3, 5, 8)
        >>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (2, 2, 5, 5)

    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, prim_name)

    _typecheck_input_batch_dot(x1_type, x2_type, prim_name)
    _check_batch_size(x1_batch_size, x2_batch_size, prim_name)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name)

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # if the original inputs were expanded to 3 dims, squeeze the inserted dim back out
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result


@constexpr
def _check_same_type(dtype1, dtype2):
    return dtype1 == dtype2


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    """Infers the shape of the last two dimensions after performing matmul."""
    shape_rem = []
    if ndim1 >= 2:
        shape_rem.append(shape1[-2])
    if transpose_b:
        if ndim2 >= 2:
            shape_rem.append(shape2[-2])
    else:
        if ndim1 >= 1:
            shape_rem.append(shape2[-1])
    return tuple(shape_rem)


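# Hedged worked example (illustrative only): for matmul with shape1 = (2, 3, 4) and
# shape2 = (4, 5) (so transpose_b is False), the trailing output dims are
#   _infer_shape_rem((2, 3, 4), (4, 5), 3, 2, False)   # -> (3, 5)
# and for a 1-D x2 such as shape2 = (4,) with transpose_b True the result is just
# (shape1[-2],).
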
@constexpr
def _check_matmul_shapes(shape1, shape2, prim_name=None):
    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 < 1 or ndim2 < 1:
        raise ValueError(f"{msg_prefix} dimension of input operands must be at least 1, but got "
                         f"the length of shape1: {ndim1}, the length of shape2: {ndim2}.")
    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
        raise ValueError(f"{msg_prefix} shape1[-1] should be equal to shape2[-2] when the length of shape2 "
                         f"is greater than or equal to 2, but got shape1[-1]: {shape1[-1]}, "
                         f"shape2[-2]: {shape2[-2]}.")
    shape_out = deque()
    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        max_size = max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f"{msg_prefix} operands could not be broadcast together with shape1 {shape1} and "
                             f"shape2 {shape2}.")
        shape_out.appendleft(max_size)
    return tuple(shape_out)


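# Hedged worked example (illustrative only): the broadcast batch shape is built from
# the leading (non-matrix) dims, e.g.
#   _check_matmul_shapes((5, 1, 3, 4), (2, 4, 6))   # -> (5, 2)
#   _check_matmul_shapes((2, 3, 4), (4, 5))         # -> (2,)
# while shapes such as (3, 4) and (5, 6) raise a ValueError because 4 != 5.
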
@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape * tile_size = out_shape."""
    size = [1] * ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)


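# Hedged worked example (illustrative only): F.tile expands size-1 dims, so
#   _tile_size((1, 3), (4, 3), 2)   # -> (4, 1)
# means tiling a (1, 3) tensor by (4, 1) to reach the broadcast shape (4, 3).
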
@constexpr
def _check_need_broadcast(shape1, shape2):
    """Returns True if broadcast is necessary for batchmatmul."""
    return shape1[:-2] != shape2[:-2]


def _expand(x, ndim):
    """Expand x to ndim by prepending dimensions of size 1 at axis 0."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x


def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """Broadcasts x from shape_cur to shape_to."""
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)


def matmul(x1, x2, dtype=None, prim_name=None):
    """
    Returns the matrix product of two arrays.

    Note:
        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
        not supported.
        On GPU and CPU, the supported dtypes are np.float16 and np.float32.

    Args:
        x1 (Tensor): Input tensor, scalar not allowed.
          The last dimension of `x1` must be the same size as the second last dimension of `x2`,
          and the shapes of `x1` and `x2` must be broadcastable.
        x2 (Tensor): Input tensor, scalar not allowed.
          The last dimension of `x1` must be the same size as the second last dimension of `x2`,
          and the shapes of `x1` and `x2` must be broadcastable.
        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
            output Tensor.

    Returns:
        Tensor or scalar, the matrix product of the inputs. This is a scalar only
        when both `x1`, `x2` are 1-d vectors.

    Raises:
        ValueError: If the last dimension of `x1` is not the same size as the
            second-to-last dimension of `x2`, or if a scalar value is passed in.
        ValueError: If the shapes of `x1` and `x2` cannot be broadcast together.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : Reasonable application of broadcast mechanism
        >>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
        >>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [[[  70.   76.   82.   88.   94.]
          [ 190.  212.  234.  256.  278.]
          [ 310.  348.  386.  424.  462.]]
         [[ 430.  484.  538.  592.  646.]
          [ 550.  620.  690.  760.  830.]
          [ 670.  756.  842.  928. 1014.]]]
        >>> print(output.shape)
        (2, 3, 5)
        >>> # case 2 : the rank of `x2` is 1
        >>> x1 = Tensor(np.ones([1, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones([2,]), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [2.]
        >>> print(output.shape)
        (1,)
    """
    # type promotion: cast both inputs to float32 if their dtypes differ
    dtype1 = F.dtype(x1)
    dtype2 = F.dtype(x2)
    if not _check_same_type(dtype1, dtype2):
        x1 = x1.astype(mstype.float32)
        x2 = x2.astype(mstype.float32)

    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
    transpose_b = ndim2_orig == 1
    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig, prim_name)
    # infers the shape of the output
    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
                                                  ndim1_orig, ndim2_orig, transpose_b)

    x1 = _expand(x1, 2)
    x2 = _expand(x2, 2)
    if F.rank(x2) == 2:
        if F.rank(x1) > 2:
            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
        res = P.MatMul(False, transpose_b)(x1, x2)
    else:
        # broadcasts x1.shape[:-2] with x2.shape[:-2]
        ndim_aligned = _max(ndim1_orig, ndim2_orig)
        x1 = _expand(x1, ndim_aligned)
        x2 = _expand(x2, ndim_aligned)
        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
        res = P.BatchMatMul(False, transpose_b)(x1, x2)

    if dtype is not None:
        res = res.astype(dtype)
    return F.reshape(res, shape_out)