# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Math operations."""
from itertools import zip_longest
from collections import deque
import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P

# count_nonzero


@constexpr
def _check_validate_axis(axis, name):
    if isinstance(axis, (tuple, list)):
        for idx, item in enumerate(axis):
            validator.check_value_type("axis[%d]" % idx, item, [int], name)
    axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
    return axis


@constexpr
def _check_validate_keepdims(keep_dims, name):
    keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
    return keep_dims


def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
    r"""
    Count the number of nonzero elements across the given axis of the input tensor.

    Args:
        x (Tensor): Input data used to count non-zero numbers.
            :math:`(N,*)` where :math:`*` means any number of additional dimensions.
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
            Default: (), reduce all dimensions.
        keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
            If false, don't keep these dimensions. Default: False.
        dtype (Union[Number, mindspore.bool\_]): The data type of the output tensor. Only constant value is allowed.
            Default: mindspore.int32

    Returns:
        Tensor, number of nonzero elements. The data type is `dtype`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1: each value specified.
        >>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 2: all values are default.
        >>> nonzero_num = ops.count_nonzero(x=x)
        >>> print(nonzero_num)
        3
        >>> # case 3: axis value was specified as 0.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
        >>> print(nonzero_num)
        [1 2 0]
        >>> # case 4: axis value was specified as 1.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
        >>> print(nonzero_num)
        [1 2]
        >>> # case 5: keep_dims value was specified.
        >>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 6: keep_dims and axis values were specified.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
        >>> print(nonzero_num)
        [[1 2 0]]
    """

    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
    axis = _check_validate_axis(axis, "count_nonzero")
    keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
    const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

    not_equal = P.NotEqual()
    cast = P.Cast()
    reduce_sum = P.ReduceSum(keep_dims)
    nonzero_bool = not_equal(x, 0)
    # ReduceSum only supports float16 or float32 tensors.
    nonzero_val = cast(nonzero_bool, mstype.float32)
    nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)

    return nonzero_num
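
# Illustrative NumPy cross-check (not part of the public API; values are taken
# from the count_nonzero docstring example above):
#     >>> np.count_nonzero(np.array([[0, 1, 0], [1, 1, 0]]))
#     3
#     >>> np.count_nonzero(np.array([[0, 1, 0], [1, 1, 0]]), axis=0)
#     array([1, 2, 0])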


# tensor dot


@constexpr
def _int_to_tuple_conv(axes):
    """
    Converts ints to tuples in input axes, expected by most validation checks.
    """
    for x in [0, 1]:
        if isinstance(axes[x], int):
            axes[x] = (axes[x],)
    return axes


@constexpr
def _check_axes(axes, prim_name=None):
    """
    Check for validity and type of axes passed to function.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
    if not isinstance(axes, int):
        axes = list(axes)  # to avoid immutability issues
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} dimension of 'axes' should be 2, but got 'axes': {axes}.")
        axes = _int_to_tuple_conv(axes)  # convert before length checks
        if len(axes[0]) != len(axes[1]):
            raise ValueError(f"{msg_prefix} first and second dim of 'axes' have to be the same size/length, "
                             f"but got 'axes': {axes}.")
        if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
            raise ValueError(f"{msg_prefix} 'axes' cannot have duplicating values, but got {axes}.")
    return axes


@constexpr
def _typecheck_input(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
                        f"and x2_type: {x2_type}.")


@constexpr
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
    """
    Convert from single int axes to 2d tuple if required
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError(f"{msg_prefix} axes must be at least 0, but got {axes}.")
        if axes == 0:
            # outer product, no input validation required
            return [], []
        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
                             f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
        x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
        x2_ind = tuple(range(len(x2_shape))[:axes])
        axes = tuple((x1_ind, x2_ind))
        axes = _int_to_tuple_conv(axes)
    return axes
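
# Illustrative only: how a single int `axes` is expanded by _axes_int_check.
# The shapes here are assumptions chosen for the example; they mirror the
# tensor_dot docstring note that axes=2 on two 3-D inputs means ((1, 2), (0, 1)):
#     x1_shape = (1, 2, 3), x2_shape = (2, 3, 4), axes = 2
#     -> last 2 dims of x1, first 2 dims of x2: axes = ((1, 2), (0, 1))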


@constexpr
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
    """
    Checks that the axes have the correct length for the given inputs, that every
    value in the axes is within range of the given shapes, and that the axes values
    are compatible between the two inputs.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    shapes = [x1_shape, x2_shape]

    # axis length check
    for ix_input, x_axes in enumerate(axes):
        axes_len = len(x_axes)
        shape_dim_len = len(shapes[ix_input])
        if axes_len > shape_dim_len:
            raise ValueError(f"{msg_prefix} length of x_axes should be less than or equal to {shape_dim_len}, "
                             f"but got 'len(x_axes)': {axes_len}.")

    # axis values range check
    for ix_input, x_axes in enumerate(axes):
        comp_shape = shapes[ix_input]
        max_val = len(comp_shape) - 1
        min_val = -1 * len(comp_shape)
        for _, x_value in enumerate(x_axes):
            if not min_val <= x_value <= max_val:
                raise ValueError(f"{msg_prefix} value in axes should be in range: [{min_val}, {max_val}], "
                                 f"but got {x_value}.")

    # check axis value with input shape - both ways for axis valid
    invalid_a = False
    invalid_b = False
    for i in range(len(axes[0])):  # sizes already validated
        if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
            invalid_a = True
        if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
            invalid_b = True
    if invalid_a and invalid_b:
        raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
                         f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
                         f"x1_shape: {x1_shape}, x2_shape: {x2_shape}, axes: {axes}.")


@constexpr
def _calc_new_shape(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations,
    'position' refers to whether tensor is first or second in the op.
    """
    contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims
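
# Worked trace of _calc_new_shape for clarity (values taken from the tensor_dot
# docstring example below; illustrative only, not executed):
#     x1_shape = (1, 2, 3), x2_shape = (3, 1, 2), axes = ((0, 1), (1, 2))
#     _calc_new_shape(x1_shape, axes, 0) -> new_shape (3, 2), perm (2, 0, 1), free dims (3,)
#     _calc_new_shape(x2_shape, axes, 1) -> new_shape (2, 3), perm (1, 2, 0), free dims (3,)
#     MatMul of (3, 2) x (2, 3) reshaped to (3, 3) reproduces the example output.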


def tensor_dot(x1, x2, axes, prim_name='tensor_dot'):
    """
    Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.

    Contraction allows for the summation of products of elements of `a` and `b` on specified axes.
    The same number of axes must be specified for both x1 and x2, and values must be within range
    of number of dims of both `a` and `b`.

    Selected dims in both inputs must also match.

    axes = 0 leads to outer product.
    axes = 1 leads to normal matrix multiplication when inputs are both 2D.
    axes = 1 is the same as axes = ((1,),(0,)) where both `a` and `b` are 2D.
    axes = 2 is the same as axes = ((1,2),(0,1)) where both `a` and `b` are 3D.

    Inputs:
        - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32
        - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32
        - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
          tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
          automatically picks up last N dims from `a` input shape and first N dims from `b` input shape in order
          as axes for each respectively.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the free
        axes not contracted in both inputs.

    Raises:
        TypeError: If `x1` or `x2` is not a Tensor.
        TypeError: If `axes` is not one of the following: int, tuple, list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> output = ops.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    # input validity checks
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    axes = _check_axes(axes, prim_name)
    _typecheck_input(x1_type, x2_type, prim_name)
    # input compatibility check & axes format update
    axes = _axes_int_check(x1_shape, x2_shape, axes, prim_name)
    _validate_axes(x1_shape, x2_shape, axes, prim_name)
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
    output_shape = x1_ret + x2_ret  # combine free axes from both inputs
    # run tensor_dot op
    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
    x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
    mul_result = matmul_op(x1_reshaped, x2_reshaped)
    final_result = reshape_op(mul_result, output_shape)
    return final_result
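
# Illustrative NumPy cross-check (an assumption for clarity; not used by the op):
# the docstring example above corresponds to
#     >>> a = np.ones((1, 2, 3), np.float32)
#     >>> b = np.ones((3, 1, 2), np.float32)
#     >>> np.tensordot(a, b, axes=((0, 1), (1, 2)))
# which also yields a (3, 3) array filled with 2.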


@constexpr
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")


@constexpr
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got "
                        f"x1_type: {x1_type} and x2_type: {x2_type}.")


@constexpr
def _get_transpose_shape(x2_shape):
    x2_shape_range = tuple(range(len(x2_shape)))
    x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
    return x2_shape_transpose
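
# Illustrative only: _get_transpose_shape moves the second-to-last axis of x2 to
# the front (the ranks below are assumptions for the example):
#     len(x2_shape) == 3 -> (1, 0, 2)
#     len(x2_shape) == 4 -> (2, 0, 1, 3)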


def dot(x1, x2, prim_name=None):
    """
    Computation of a dot product between samples in two tensors.

    Inputs:
        - **x1** (Tensor) - First tensor in Dot op with datatype float16 or float32.
          The rank must be greater than or equal to 2.
        - **x2** (Tensor) - Second tensor in Dot op with datatype float16 or float32.
          The rank must be greater than or equal to 2.

    Outputs:
        Tensor, dot product of x1 and x2.

    Raises:
        TypeError: If type of x1 and x2 are not the same.
        TypeError: If dtype of x1 or x2 is not float16 or float32.
        ValueError: If rank of x1 or x2 is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
        >>> print(output.shape)
        (2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]]
          [[3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]
           [3. 3.]]
          [[3. 3.]
           [3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 2, 2)
        >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]]
        >>> print(output.shape)
        (3, 2, 2, 1, 2)
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    _typecheck_input_dot(x1_type, x2_type, prim_name)
    _check_invalid_input(x1_shape, x2_shape, prim_name)

    if len(x1_shape) > 2 or len(x2_shape) > 2:
        x2_shape_transpose = _get_transpose_shape(x2_shape)
        x2_transpose = transpose_op(x2, x2_shape_transpose)
        x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
        x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
        mul_result = matmul_op(x1_reshape, x2_reshape)
        return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
    return matmul_op(x1, x2)
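
# Illustrative NumPy cross-check (an assumption for clarity; not executed here):
# dot() contracts the last axis of x1 with the second-to-last axis of x2, so the
# first docstring example above matches
#     >>> np.dot(np.ones((2, 3), np.float32), np.ones((1, 3, 2), np.float32)).shape
#     (2, 1, 2)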


@constexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
    """
    Get batch sizes from two inputs
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
    return x1_shape[0], x2_shape[0]


@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} and "
                        f"x2_type: {x2_type}.")


@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
    """
    Check whether axes are valid and cast axes from tuple to list
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError(f"{msg_prefix} axes cannot contain 0, but got axes: {axes}.")
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} length of axes must be equal to 2, but got {len(axes)}.")
        if isinstance(axes, tuple):
            axes = list(axes)
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Reverse if axis < 0
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
        validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
                             f"and axes[1] must be less than or equal to len(x2_shape). "
                             f"But got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError(f"{msg_prefix} axes should not equal to 0, but got {axes}.")
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
            validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
        elif axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes cannot be greater than the length of x1_shape and x2_shape, "
                             f"but got axes: {axes}, x1_shape: {x1_shape}, x2_shape: {x2_shape}.")
        else:
            axes = [axes, axes]
    else:
        raise ValueError(f"{msg_prefix} type of axes must be one of those: int, tuple(int), list(int), "
                         f"but got {type(axes).__name__}.")
    return axes
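
# Illustrative only: the default axes chosen by _check_axes_for_batch_dot when
# `axes is None` (shapes are taken from the batch_dot docstring example below):
#     x1 (6, 2, 3, 4), x2 (6, 5, 4, 8) -> axes = [3, 2], i.e. the last dim of x1
#     is contracted with the second-to-last dim of x2, giving output (6, 2, 3, 5, 8).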


@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations,
    'position' refers to whether tensor is first or second in the op.
    """
    axis = axes[position]
    contraction_axes = tuple([axis])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    transpose_perm = tuple([0]) + transpose_perm
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    new_shape = tuple([shape[0]]) + new_shape
    return new_shape, transpose_perm, free_dims


@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
    """
    Check whether batch size of two inputs are the same
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if x1_batch_size != x2_batch_size:
        raise ValueError(f"{msg_prefix} both inputs x1, x2 should have the same batch sizes, but got "
                         f"x1_batch_size: {x1_batch_size} and x2_batch_size: {x2_batch_size}.")


@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Compute output shape for batch dot
    """
    output_shape = tuple([batch_size]) + x1_ret + x2_ret
    return output_shape
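
# Worked shape trace for clarity (values taken from the batch_dot docstring
# example with default axes; illustrative only, not executed):
#     x1 (6, 2, 3, 4), x2 (6, 5, 4, 8), axes = [3, 2]
#     _calc_new_shape_batchdot(x1_shape, axes, 0) -> (6, 6, 4),  perm (0, 1, 2, 3), free dims (2, 3)
#     _calc_new_shape_batchdot(x2_shape, axes, 1) -> (6, 4, 40), perm (0, 2, 1, 3), free dims (5, 8)
#     BatchMatMul of (6, 6, 4) x (6, 4, 40) -> (6, 6, 40), reshaped to (6, 2, 3, 5, 8).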


def batch_dot(x1, x2, axes=None, prim_name=None):
    """
    Computation of batch dot product between samples in two tensors containing batch dims.

    .. math::
        output = x1[batch, :] * x2[batch, :]

    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float32 and the rank of `x1` must be greater
          than or equal to 2.
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
          be same as `x1` and the rank of `x2` must be greater than or equal to 2.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with the contraction
          dimensions specified for `x1` and `x2` each. If a single value `N` is passed, it is used as the
          contraction axis for both `x1` and `x2`; a negative value counts from the end of each input's shape.
          Default: None.

    Outputs:
        Tensor, batch dot product of `x1` and `x2`. For example: the shape of the output
        for input `x1` shape (batch, d1, axes, d2) and `x2` shape (batch, d3, axes, d4) is (batch, d1, d2, d3, d4),
        where d1 and d2 can be any number.

    Raises:
        TypeError: If type of x1 and x2 are not the same.
        TypeError: If dtype of x1 or x2 is not float32.
        ValueError: If rank of x1 or x2 is less than 2.
        ValueError: If batch dim used in axes.
        ValueError: If len(axes) is less than 2.
        ValueError: If axes is not one of those: None, int, (int, int).
        ValueError: If axes reversed from negative int is too low for dimensions of input arrays.
        ValueError: If axes value is too high for dimensions of input arrays.
        ValueError: If batch size of x1 and x2 are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
        >>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (1, 2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]]
        >>> print(output.shape)
        (2, 3)
        >>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (6, 2, 3, 5, 8)
        >>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (2, 2, 5, 5)

    """

    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, prim_name)

    _typecheck_input_batch_dot(x1_type, x2_type, prim_name)
    _check_batch_size(x1_batch_size, x2_batch_size, prim_name)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name)

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # if the original dims are expanded, restore them from 3 to 2
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result
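
# Illustrative per-sample view (an assumption for clarity; not executed here):
# with axes=(-1, -2) and 3-D inputs, as in the first batch_dot docstring example,
# each batch element reduces to an ordinary matrix product:
#     out[b] = np.matmul(x1[b], x2[b]) for b in range(batch_size)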


@constexpr
def _check_same_type(dtype1, dtype2):
    return dtype1 == dtype2


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    """Infers the shape of the last two dimensions after performing matmul."""
    shape_rem = []
    if ndim1 >= 2:
        shape_rem.append(shape1[-2])
    if transpose_b:
        if ndim2 >= 2:
            shape_rem.append(shape2[-2])
    else:
        if ndim1 >= 1:
            shape_rem.append(shape2[-1])
    return tuple(shape_rem)


@constexpr
def _check_matmul_shapes(shape1, shape2, prim_name=None):
    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 < 1 or ndim2 < 1:
        raise ValueError(f"{msg_prefix} dimension of input operands must be at least 1, but got "
                         f"the length of shape1: {ndim1}, the length of shape2: {ndim2}.")
    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
        raise ValueError(f"{msg_prefix} shape1[-1] should be equal to shape2[-2] when the length of shape2 "
                         f"is greater than or equal to 2, but got shape1[-1]: {shape1[-1]}, "
                         f"shape2[-2]: {shape2[-2]}.")
    shape_out = deque()
    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        max_size = max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f"{msg_prefix} operands could not be broadcast together with shape1 {shape1} and "
                             f"shape2 {shape2}.")
        shape_out.appendleft(max_size)
    return tuple(shape_out)


@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape"""
    size = [1] * ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)


@constexpr
def _check_need_broadcast(shape1, shape2):
    """Returns True if broadcast is necessary for batchmatmul."""
    return shape1[:-2] != shape2[:-2]


def _expand(x, ndim):
    """Expand x to ndim from axis, which can be 0 or -1."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x


def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """Broadcasts x from shape_cur to shape_to."""
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)
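
# Worked broadcast example for clarity (the shapes below are hypothetical, chosen
# only to illustrate _check_matmul_shapes and _infer_shape_rem; not executed):
#     shape1 = (2, 1, 3, 4), shape2 = (1, 5, 4, 2)
#     batch dims (2, 1) and (1, 5) broadcast to (2, 5)
#     remaining dims shape1[-2] = 3 and shape2[-1] = 2 -> output shape (2, 5, 3, 2)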


def matmul(x1, x2, dtype=None, prim_name=None):
    """
    Returns the matrix product of two arrays.

    Note:
        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
        not supported.
        On GPU, the supported dtypes are np.float16 and np.float32.
        On CPU, the supported dtypes are np.float16 and np.float32.

    Args:
        x1 (Tensor): Input tensor, scalar not allowed.
            The last dimension of `x1` must be the same size as the second last dimension of `x2`.
            And the shapes of x1 and x2 could be broadcast.
        x2 (Tensor): Input tensor, scalar not allowed.
            The last dimension of `x1` must be the same size as the second last dimension of `x2`.
            And the shapes of x1 and x2 could be broadcast.
        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
            output Tensor.

    Returns:
        Tensor or scalar, the matrix product of the inputs. This is a scalar only
        when both `x1`, `x2` are 1-d vectors.

    Raises:
        ValueError: If the last dimension of `x1` is not the same size as the
            second-to-last dimension of `x2`, or if a scalar value is passed in.
        ValueError: If the shapes of `x1` and `x2` cannot be broadcast together.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : Reasonable application of broadcast mechanism
        >>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
        >>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [[[ 70. 76. 82. 88. 94.]
          [ 190. 212. 234. 256. 278.]
          [ 310. 348. 386. 424. 462.]]
         [[ 430. 484. 538. 592. 646.]
          [ 550. 620. 690. 760. 830.]
          [ 670. 756. 842. 928. 1014.]]]
        >>> print(output.shape)
        (2, 3, 5)
        >>> # case 2 : the rank of `x2` is 1
        >>> x1 = Tensor(np.ones([1, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones([2,]), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [2.]
        >>> print(output.shape)
        (1,)
    """
    # performs type promotion
    dtype1 = F.dtype(x1)
    dtype2 = F.dtype(x2)
    if not _check_same_type(dtype1, dtype2):
        x1 = x1.astype(mstype.float32)
        x2 = x2.astype(mstype.float32)

    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
    transpose_b = ndim2_orig == 1
    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig, prim_name)
    # infers the shape of the output
    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
                                                  ndim1_orig, ndim2_orig, transpose_b)

    x1 = _expand(x1, 2)
    x2 = _expand(x2, 2)
    if F.rank(x2) == 2:
        if F.rank(x1) > 2:
            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
        res = P.MatMul(False, transpose_b)(x1, x2)
    else:
        # broadcasts x1.shape[:-2] with x2.shape[:-2]
        ndim_aligned = _max(ndim1_orig, ndim2_orig)
        x1 = _expand(x1, ndim_aligned)
        x2 = _expand(x2, ndim_aligned)
        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
        res = P.BatchMatMul(False, transpose_b)(x1, x2)

    if dtype is not None:
        res = res.astype(dtype)
    return F.reshape(res, shape_out)
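
# Illustrative NumPy cross-check (an assumption for clarity; not executed here):
# the broadcasting above mirrors numpy.matmul, e.g. for case 1 of the docstring,
#     >>> np.matmul(np.arange(2*3*4).reshape(2, 3, 4), np.arange(4*5).reshape(4, 5)).shape
#     (2, 3, 5)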