# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Defines sparse operators with functional form."""

from __future__ import absolute_import
from typing import Tuple
from mindspore.ops.operations.sparse_ops import (
    DenseToCSRSparseMatrix,
    CSRSparseMatrixToSparseTensor,
    SparseConcat,
    SparseAdd,
    SparseMatrixAdd,
    SparseMatrixSoftmax,
    SparseMatrixSparseMatMul,
    CSRSparseMatrixToDense
)
from mindspore import ops
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops.operations.array_ops import GatherNd, Coalesce
from mindspore.ops.operations import _csr_ops
from mindspore.ops import functional as F
from mindspore.common import CSRTensor, COOTensor, Tensor
from mindspore.ops.composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error, make_tensor,\
    promote_binary_dtype

# utility functions and values
gather_nd = GatherNd()
dense_to_csr = DenseToCSRSparseMatrix()
csr_sparse_matrix_to_sparse_tensor = CSRSparseMatrixToSparseTensor()
batch_csr_pointers_empty = Tensor([0, -1], dtype=mstype.int32)
coalesce_op = Coalesce()
csr_sparse_matrix_to_dense = CSRSparseMatrixToDense()


@constexpr
def print_info(info):
    """Print given error info"""
    print(info)


@constexpr
def _make_tensor(data):
    """Make Tensor"""
    return Tensor(data)


@constexpr
def _make_tensor_with_dtype(data, dtype):
    """Make Tensor with specific datatype"""
    return Tensor(data, dtype=dtype)


def _convert_shape(shape):
    """Temporary solution to get shape value, will be removed when shape op is supported."""
    if F.is_sequence_shape_unknown(shape):
        return (-2,)
    shape = [-1 if not F.isconstant(i) else i for i in shape]
    return tuple(shape)


def is_scalar(tensor):
    """Determine whether tensor input is a scalar tensor."""
    if tensor.size != 1:
        return False
    return len(tensor.shape) <= 2


def promote_tensor(tensor_1, tensor_2):
    """Promote two tensors to a common dtype."""
    dtype = promote_binary_dtype(tensor_1.dtype, tensor_2.dtype)
    return tensor_1.astype(dtype), tensor_2.astype(dtype)


def promote_csr(csr_tensor_1, csr_tensor_2):
    """Type promotion for CSR tensor."""
    indptr_1, indptr_2 = promote_tensor(csr_tensor_1.indptr, csr_tensor_2.indptr)
    indices_1, indices_2 = promote_tensor(csr_tensor_1.indices, csr_tensor_2.indices)
    values_1, values_2 = promote_tensor(csr_tensor_1.values, csr_tensor_2.values)
    csr_tensor_1 = CSRTensor(indptr_1, indices_1, values_1, csr_tensor_1.shape)
    csr_tensor_2 = CSRTensor(indptr_2, indices_2, values_2, csr_tensor_2.shape)
    return csr_tensor_1, csr_tensor_2


def promote_coo(coo_tensor_1, coo_tensor_2):
    """Type promotion for COO tensor."""
    indices_1, indices_2 = promote_tensor(coo_tensor_1.indices, coo_tensor_2.indices)
    values_1, values_2 = promote_tensor(coo_tensor_1.values, coo_tensor_2.values)
    coo_tensor_1 = COOTensor(indices_1, values_1, coo_tensor_1.shape)
    coo_tensor_2 = COOTensor(indices_2, values_2, coo_tensor_2.shape)
    return coo_tensor_1, coo_tensor_2


def coalesce(x_indices: Tensor, x_values: Tensor, x_shape: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Returns the coalesced sparse tensor of the input.

    Args:
        - **x_indices** (Tensor) - A 2-D Tensor, represents the indices of the nonzero elements of the sparse tensor.
          Supported data type is int64. Its elements should be non-negative. The shape is :math:`(y, x)`.
        - **x_values** (Tensor) - A 1-D Tensor, represents the values corresponding to the indices in `x_indices`.
          Supported data types are float16 and float32. The shape is :math:`(x,)`.
        - **x_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor.
          Supported data type is int64. The shape is :math:`(y,)`.

    Returns:
        - **y_indices** (Tensor) - A 2-D Tensor, represents the indices of the nonzero elements of the sparse tensor.
          Data type is int64. Its elements are non-negative. The shape is :math:`(y, z)`.
          `z` represents the number of different indices in `x_indices`.
        - **y_values** (Tensor) - A 1-D Tensor, represents the values corresponding to the indices in `y_indices`.
          Data type is the same as `x_values`'s. The shape is :math:`(z,)`.
        - **y_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor.
          Data type is int64. The shape is :math:`(y,)`.

    Raises:
        TypeError: If the data type of `x_values` is neither float32 nor float16.
        TypeError: If any of the data types of `x_indices` and `x_shape` is not int64.
        ValueError: If any of `x_values` and `x_shape` is not a 1-D tensor.
        ValueError: If `x_indices` is not a 2-D tensor.
        ValueError: If sizes of second dimension of `x_indices` and first dimension of `x_values` are not the same.
        ValueError: If sizes of first dimension of `x_indices` and first dimension of `x_shape` are not the same.
        ValueError: If any of the values of elements of `x_indices` is negative.
        ValueError: If any of the values of elements of `x_indices` exceed the limit set by `x_shape`.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> x_indices = Tensor([[0, 0, 1], [1, 1, 2]], dtype=ms.int64)
        >>> x_values = Tensor([1, 5, 4], dtype=ms.float32)
        >>> x_shape = Tensor([3, 3], dtype=ms.int64)
        >>> y_indices, y_values, y_shape = ops.coalesce(x_indices, x_values, x_shape)
        >>> print(y_indices)
        [[0 1]
         [1 2]]
        >>> print(y_values)
        [6. 4.]
        >>> print(y_shape)
        [3 3]
    """
    return coalesce_op(x_indices, x_values, x_shape)


coo2csr = _csr_ops.COO2CSR()

coo_tensor_get_dense_shape = Primitive('COOTensorGetDenseShape')

coo_tensor_get_indices = Primitive('COOTensorGetIndices')

coo_tensor_get_values = Primitive('COOTensorGetValues')


def csr_div(x: CSRTensor, y: Tensor) -> Tensor:
    """
    Returns x / y where x is CSRTensor and y is Tensor.

    Note:
        This function returns a dense Tensor representing the non-zero values of
        the CSRTensor. If a CSRTensor output is expected, please use the `/`
        operator directly instead. Only broadcasting a dense tensor to the
        sparse tensor is supported at the moment.

    Args:
        x (CSRTensor): Sparse CSR Tensor.
        y (Tensor): Dense Tensor, its shape must be able to broadcast to x.

    Returns:
        Dense Tensor, represents the non-zero values of the result.

    Supported Platforms:
        ``GPU`` ``CPU``
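
    Examples:
        A minimal usage sketch (added here for illustration, not from the
        original docstring); it exercises the scalar branch of the
        implementation below, so the printed values follow directly from
        ``x.values / 2.0`` (output formatting is approximate):

        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> indices = Tensor([0, 1], dtype=mstype.int32)
        >>> values = Tensor([2., 1.], dtype=mstype.float32)
        >>> x = CSRTensor(indptr, indices, values, (2, 4))
        >>> out = ops.csr_div(x, Tensor(2.0, mstype.float32))
        >>> print(out)
        [1.  0.5]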
188    """
189    if isinstance(y, (int, float, bool)):
190        y = _make_tensor(y)
191    if is_scalar(y):
192        if y.ndim > x.ndim:
193            raise_value_error("dense tensor cannot broadcast to the sparse tensor.")
194        return (x.values / y).reshape(x.values.shape)
195    x_values, y = promote_tensor(x.values, y)
196    res_values = _csr_ops.CSRDiv()(x.indptr, x.indices, x_values, x.shape, y)
197    return CSRTensor(x.indptr, x.indices, res_values, x.shape)
198
199
200csr_gather = _csr_ops.CSRGather()
201
202
203def csr_mul(x: CSRTensor, y: Tensor) -> CSRTensor:
204    """
205    Returns x * y where x is CSRTensor and y is Tensor.
206
207    Args:
208        x (CSRTensor): Sparse CSR Tensor.
209        y (Tensor): Dense Tensor, its shape must be able to broadcast to x.
210
211    Returns:
212        CSRTensor.
213
214    Supported Platforms:
215        ``GPU`` ``CPU``
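
    Examples:
        A minimal usage sketch (added here for illustration, not from the
        original docstring). With a scalar-like `y`, the scalar branch of the
        implementation below returns the scaled non-zero values as a dense
        Tensor, so the printed result is ``x.values * 3.0`` (output formatting
        is approximate):

        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> indices = Tensor([0, 1], dtype=mstype.int32)
        >>> values = Tensor([2., 1.], dtype=mstype.float32)
        >>> x = CSRTensor(indptr, indices, values, (2, 4))
        >>> print(ops.csr_mul(x, Tensor(3.0, mstype.float32)))
        [6. 3.]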
216    """
217    if isinstance(y, (int, float, bool)):
218        y = _make_tensor(y)
219    if is_scalar(y):
220        if y.ndim > x.ndim:
221            raise_value_error("dense tensor cannot broadcast to the sparse tensor.")
222        return (x.values * y).reshape(x.values.shape)
223    x_values, y = promote_tensor(x.values, y)
224    res_values = _csr_ops.CSRMul()(x.indptr, x.indices, x_values, x.shape, y)
225    return CSRTensor(x.indptr, x.indices, res_values, x.shape)
226
227
228def csr_mv(csr_tensor: CSRTensor, dense: Tensor) -> Tensor:
229    """
230    Sparse matrix-vector multiplication.
231
232    Args:
233        csr_tensor (CSRTensor): Sparse CSR Tensor.
234        dense (Tensor): Dense Tensor.
235
236    Returns:
237        Dense Tensor.
238
239    Supported Platforms:
240        ``GPU`` ``CPU``
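
    Examples:
        A minimal usage sketch (added here for illustration, not from the
        original docstring); the expected result assumes standard sparse
        matrix-vector multiplication of a `(2, 4)` CSRTensor with a `(4, 1)`
        dense vector (output formatting is approximate):

        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> indices = Tensor([0, 1], dtype=mstype.int32)
        >>> values = Tensor([2., 1.], dtype=mstype.float32)
        >>> csr_tensor = CSRTensor(indptr, indices, values, (2, 4))
        >>> dense = Tensor([[1.], [1.], [1.], [1.]], dtype=mstype.float32)
        >>> print(ops.csr_mv(csr_tensor, dense))
        [[2.]
         [1.]]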
241    """
242    csr_tensor_values, dense = promote_tensor(csr_tensor.values, dense)
243    return _csr_ops.CSRMV()(csr_tensor.indptr, csr_tensor.indices, csr_tensor_values, csr_tensor.shape, dense)
244
245
246def csr_mm(a: CSRTensor, b: CSRTensor, trans_a: bool = False, trans_b: bool = False,
247           adjoint_a: bool = False, adjoint_b: bool = False):
248    """
249    Return the matrix multiplication result of the right-multiply matrix (dense or CSRTensor) of the CSRTensor.
250    The CSRTensor with shape `[M, N]` needs to adapt the right matrix with shape `[N, K]`
251    to get the dense matrix or CSRTensor with result `[M, K]`.
252
253    Note:
254        If right matrix is CSRTensor, currently only supports GPU backend.
255        If right matrix is Tensor, currently supports CPU backend with LLVM no lower than 12.0.1, and GPU backend.
256
257    Args:
258        a (CSRTensor): Sparse CSR Tensor, rank should be 2.
259        b (CSRTensor): Sparse CSR Tensor, rank should be 2.
260        trans_a (bool, optional): whether to transpose CSRTensor a. Default: ``False`` .
261        trans_b (bool, optional): whether to transpose CSRTensor b. Default: ``False`` .
262        adjoint_a (bool, optional): whether to adjoint CSRTensor a. Default: ``False`` .
263        adjoint_b (bool, optional): whether to adjoint CSRTensor b. Default: ``False`` .
264
265    Returns:
266        CSRTensor.
267
268    Supported Platforms:
269        ``GPU``
270
271    Examples:
272        >>> from mindspore import Tensor, CSRTensor
273        >>> from mindspore import dtype as mstype
274        >>> from mindspore import ops
275        >>> a_shape = (4, 5)
276        >>> a_indptr = Tensor([0, 1, 1, 3, 4], dtype=mstype.int32)
277        >>> a_indices = Tensor([0, 3, 4, 0],dtype=mstype.int32)
278        >>> a_values = Tensor([1.0, 5.0, -1.0, -2.0], dtype=mstype.float32)
279        >>> b_shape = (5, 3)
280        >>> b_indptr = Tensor([0, 1, 1, 3, 3, 3], dtype=mstype.int32)
281        >>> b_indices = Tensor([0, 0, 1],dtype=mstype.int32)
282        >>> b_values = Tensor([2.0, 7.0, 8.0], dtype=mstype.float32)
283        >>> a = CSRTensor(a_indptr, a_indices, a_values, a_shape)
284        >>> b = CSRTensor(b_indptr, b_indices, b_values, b_shape)
285        >>> c = ops.csr_mm(a, b)
286        >>> print(c.shape)
287        (4, 3)
288        >>> print(c.values)
289        [2. -4.]
290        >>> print(c.indptr)
291        [0 1 1 1 2]
292        >>> print(c.indices)
293        [0 0]
294    """
295    if not isinstance(a, CSRTensor) or not isinstance(b, CSRTensor):
296        raise_type_error("For functional operator csr_mm, inputs a and b must be type of CSRTensor currently.")
297    a_batch_pointers = make_tensor([0, a.values.shape[0]], a.indices.dtype)
298    b_batch_pointers = make_tensor([0, b.values.shape[0]], b.indices.dtype)
299    a_shape = make_tensor(a.shape, a.indices.dtype)
300    b_shape = make_tensor(b.shape, b.indices.dtype)
301    sparse_matrix_sparse_matmul = SparseMatrixSparseMatMul(transpose_a=trans_a,
302                                                           transpose_b=trans_b,
303                                                           adjoint_a=adjoint_a,
304                                                           adjoint_b=adjoint_b)
305
306    _, _, c_indptr, c_indices, c_values = sparse_matrix_sparse_matmul(a_shape,
307                                                                      a_batch_pointers,
308                                                                      a.indptr,
309                                                                      a.indices,
310                                                                      a.values,
311                                                                      b_shape,
312                                                                      b_batch_pointers,
313                                                                      b.indptr,
314                                                                      b.indices,
315                                                                      b.values)
316    if trans_a or adjoint_a:
317        return CSRTensor(c_indptr, c_indices, c_values, (a.shape[1], b.shape[1]))
318    if trans_b or adjoint_b:
319        return CSRTensor(c_indptr, c_indices, c_values, (a.shape[0], b.shape[0]))
320    return CSRTensor(c_indptr, c_indices, c_values, (a.shape[0], b.shape[1]))
321
322
323def csr_reduce_sum(csr_tensor: CSRTensor, axis: int) -> Tensor:
324    """
325    Reduces a dimension of a CSRTensor by summing all elements in the dimension.
326
327    Args:
328        csr_tensor (CSRTensor): Sparse CSR Tensor.
329        axis (int): Axis to be reduced.
330
331    Returns:
332        Dense Tensor, represents the non-zero values of the result.
333
334    Supported Platforms:
335        ``GPU`` ``CPU``
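
    Examples:
        A minimal usage sketch (added here for illustration, not from the
        original docstring); summing over axis 1 of a `(2, 4)` CSRTensor is
        expected to give the per-row sums as a `(2, 1)` dense Tensor (output
        formatting is approximate):

        >>> from mindspore import Tensor, CSRTensor, ops
        >>> from mindspore import dtype as mstype
        >>> indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> indices = Tensor([0, 1], dtype=mstype.int32)
        >>> values = Tensor([2., 1.], dtype=mstype.float32)
        >>> csr_tensor = CSRTensor(indptr, indices, values, (2, 4))
        >>> print(ops.csr_reduce_sum(csr_tensor, 1))
        [[2.]
         [1.]]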
336    """
337    return _csr_ops.CSRReduceSum()(csr_tensor.indptr, csr_tensor.indices, csr_tensor.values, csr_tensor.shape, axis)
338
339
340def csr_to_coo(tensor: CSRTensor) -> COOTensor:
341    """
342    Converts a CSRTensor to COOTensor.
343
344    Note:
345        Only 2-D CSRTensor is supported for now.
346
347    Args:
348        tensor (CSRTensor): A CSRTensor, must be 2-D.
349
350    Returns:
351        2D COOTensor, the input tensor stored in COO format.
352
353    Raises:
354        TypeError: If input is not a COOTensor.
355        ValueError: If input tensor is not 2-D.
356
357    Supported Platforms:
358        ``Ascend`` ``GPU`` ``CPU``
359
360    Examples:
361        >>> from mindspore import Tensor, CSRTensor
362        >>> indptr = Tensor([0, 1, 2]).astype("int32")
363        >>> indices = Tensor([0, 1]).astype("int32")
364        >>> values = Tensor([2, 1]).astype("float32")
365        >>> shape = (2, 4)
366        >>> x = CSRTensor(indptr, indices, values, shape)
367        >>> output = ops.csr_to_coo(x)
368        >>> print(output.indices)
369        [[0 0]
370        [1 1]]
371    """
372    if not isinstance(tensor, CSRTensor):
373        raise_type_error("For functional operator csr_to_coo, input argument must be a CSRTensor.")
374    if len(_convert_shape(tensor.shape)) > 2:
375        raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
376    shape = tensor.shape
377    indices, values, _ = csr_sparse_matrix_to_sparse_tensor(Tensor(shape, mstype.int32), batch_csr_pointers_empty,
378                                                            tensor.indptr, tensor.indices, tensor.values)
379    return COOTensor(indices, values, _convert_shape(shape))
380

def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
    """
    Converts a CSRTensor to its dense form.

    Note:
        Only 2-D CSRTensor is supported for now.

    Args:
        csr_tensor (CSRTensor): A CSRTensor, must be 2-D.

    Returns:
        Tensor.

    Raises:
        TypeError: If input is not a CSRTensor.
        ValueError: If input CSRTensor is not 2-D.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> from mindspore import Tensor, CSRTensor, ops
        >>> indptr = Tensor([0, 1, 2]).astype("int32")
        >>> indices = Tensor([0, 1]).astype("int32")
        >>> values = Tensor([2, 1]).astype("float32")
        >>> shape = (2, 4)
        >>> x = CSRTensor(indptr, indices, values, shape)
        >>> output = ops.csr_to_dense(x)
        >>> print(output)
        [[2. 0. 0. 0.]
         [0. 1. 0. 0.]]
    """
    if not isinstance(csr_tensor, CSRTensor):
        raise_type_error("For functional operator csr_to_dense, input argument must be a CSRTensor.")
    if len(csr_tensor.shape) > 2:
        raise_value_error("Currently only support 2-D CSRTensor when converting to a dense Tensor.")

    shape = _convert_shape(csr_tensor.shape)
    dense_shape = Tensor(shape, dtype=mstype.int32)
    batch_pointers = ops.concat((make_tensor([0]), ops.TensorShape()(csr_tensor.values))).astype("int32")
    row_pointers = csr_tensor.indptr
    col_indices = csr_tensor.indices
    values = csr_tensor.values

    valid_indices_dtype = [mstype.int32, mstype.int64]
    if row_pointers.dtype in valid_indices_dtype and col_indices.dtype in valid_indices_dtype:
        if mstype.int64 in (row_pointers.dtype, col_indices.dtype):
            return csr_sparse_matrix_to_dense(dense_shape.astype(mstype.int64), batch_pointers.astype(mstype.int64),
                                              row_pointers.astype(mstype.int64), col_indices.astype(mstype.int64),
                                              values)

    return csr_sparse_matrix_to_dense(dense_shape, batch_pointers, row_pointers, col_indices, values)


# deprecated, will be removed once `csr_to_coo` supports all backends.
csr2coo = _csr_ops.CSR2COO()

csr_tensor_get_dense_shape = Primitive('CSRTensorGetDenseShape')

csr_tensor_get_indices = Primitive('CSRTensorGetIndices')

csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')

csr_tensor_get_values = Primitive('CSRTensorGetValues')


def dense_to_sparse_coo(tensor: Tensor) -> COOTensor:
    """
    Convert a Tensor to COOTensor.

    Note:
        Only 2-D tensor is supported for now.

    Args:
        tensor (Tensor): A dense tensor, must be 2-D.

    Returns:
        COOTensor, a sparse representation of the original dense tensor, containing the following parts.

        - indices (Tensor): 2-D integer tensor, indicates the positions of `values` of the dense tensor.
        - values (Tensor): 1-D tensor, indicates the non-zero values of the dense tensor.
        - shape (tuple(int)): the shape of the COOTensor, is the same as the original dense tensor.

    Raises:
        TypeError: If input is not a tensor.
        ValueError: If input tensor is not 2-D.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> import mindspore as ms
        >>> x = Tensor([[1, 0], [-5, 0]], ms.float32)
        >>> output = ops.dense_to_sparse_coo(x)
        >>> print(output.indices)
        [[0 0]
         [1 0]]
        >>> print(output.values)
        [ 1. -5.]
        >>> print(output.shape)
        (2, 2)
    """
    if not isinstance(tensor, Tensor):
        raise_type_error("For functional operator dense_to_sparse_coo, input argument must be a Tensor.")
    if len(_convert_shape(tensor.shape)) > 2:
        raise_value_error("Currently only support 2-D Tensor when converting to COOTensor.")
    indices = tensor.nonzero().astype("int32")
    values = gather_nd(tensor, indices)
    return COOTensor(indices, values, _convert_shape(tensor.shape))


def dense_to_sparse_csr(tensor: Tensor) -> CSRTensor:
    """
    Convert a Tensor to CSRTensor.

    Note:
        Only 2-D tensor is supported for now.

    Args:
        tensor (Tensor): A dense tensor, must be 2-D.

    Returns:
        CSRTensor, a sparse representation of the original dense tensor, containing the following parts.

        - indptr (Tensor): 1-D integer tensor, indicates the start and end point for `values` in each row.
        - indices (Tensor): 1-D integer tensor, indicates the column positions of all non-zero values of the input.
        - values (Tensor): 1-D tensor, indicates the non-zero values of the dense tensor.
        - shape (tuple(int)): the shape of the CSRTensor, is the same as the original dense tensor.

    Raises:
        TypeError: If input is not a tensor.
        ValueError: If input tensor is not 2-D.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> import mindspore as ms
        >>> x = Tensor([[1, 0], [-5, 0]], ms.float32)
        >>> output = ops.dense_to_sparse_csr(x)
        >>> print(output.indptr)
        [0 1 2]
        >>> print(output.indices)
        [0 0]
        >>> print(output.shape)
        (2, 2)
    """
    if not isinstance(tensor, Tensor):
        raise_type_error("For functional operator dense_to_sparse_csr, input argument must be a Tensor.")
    if len(_convert_shape(tensor.shape)) > 2:
        raise_value_error("Currently only support 2-D Tensor when converting to CSRTensor.")
    indices = tensor.nonzero().astype("int32")
    _, _, indptr, indices, values = dense_to_csr(tensor, indices)
    return CSRTensor(indptr, indices, values, _convert_shape(tensor.shape))


def make_sparse_tensor(indices, values, dense_shape):
    """Call make_coo_tensor in this function."""
    print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
               "Please use 'COOTensor' instead.")
    return make_coo_tensor(indices, values, dense_shape)


def make_row_tensor(indices, values, dense_shape):
    """Call make_row_tensor_inner in this function."""
    print_info("WARNING: 'RowTensor' is deprecated from version 2.0 and will be removed in a future version.")
    return make_row_tensor_inner(indices, values, dense_shape)


make_coo_tensor = Primitive('MakeCOOTensor')

make_csr_tensor = Primitive('MakeCSRTensor')

make_row_tensor_inner = Primitive('MakeRowTensor')

make_map_parameter = Primitive('MakeMapParameter')

row_tensor_get_values = Primitive('RowTensorGetValues')

row_tensor_get_indices = Primitive('RowTensorGetIndices')

row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')

row_tensor_add = Primitive('RowTensorAdd')


@constexpr
def _calc_out_shape(sp_input, concat_dim):
    "Calculate the COOTensor output shape in coo_concat."
    if isinstance(sp_input[0], tuple):
        out_shape_list = list(sp_input[0][2])
    else:
        out_shape_list = list(sp_input[0].shape)
    for i in range(1, len(sp_input)):
        if isinstance(sp_input[i], tuple):
            out_shape_list[concat_dim] += sp_input[i][2][concat_dim]
        else:
            out_shape_list[concat_dim] += sp_input[i].shape[concat_dim]
    return tuple(out_shape_list)


@constexpr
def _set_coo_concat_input(sp_input):
    "Split the input COOTensors into their indices, values and shape tensors."
    if len(sp_input) < 2:
        raise_value_error("For coo_concat, at least 2 input COOTensors are required.")
    in_indices = []
    in_values = []
    in_shapes = []
    for element in sp_input:
        if isinstance(element, tuple):
            in_indices.append(element[0])
            in_values.append(element[1])
            in_shapes.append(Tensor(element[2], dtype=mstype.int64))
        else:
            in_indices.append(element.indices)
            in_values.append(element.values)
            in_shapes.append(Tensor(element.shape, dtype=mstype.int64))
    return in_indices, in_values, in_shapes


def coo_concat(sp_input, concat_dim=0):
    """
    Concatenates the input SparseTensors (COO format) along the specified dimension.

    .. warning::
        This is an experimental API that is subject to change or deletion. Only supported on CPU now.

    Args:
        sp_input (Union[list(COOTensor), tuple(COOTensor)]): the list of COOTensors to be concatenated.
        concat_dim (scalar): the dimension to concatenate along. The value must be in range [-rank, rank),
            where rank is the number of dimensions in each input SparseTensor. Default: ``0`` .

    Returns:
        - **output** (COOTensor) - the result of concatenating the input SparseTensors along the
          specified dimension. The output shape equals the input shape on every dimension except
          `concat_dim`, where it is the sum of the inputs' sizes along that dimension.

    Raises:
        ValueError: If only one sparse tensor is given as input.
        ValueError: If an input COOTensor has shape dim > 3. COOTensor shape dim size must be 2 now.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops, COOTensor
        >>> from mindspore import dtype as mstype
        >>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int64)
        >>> values0 = Tensor([1, 2], dtype=mstype.int32)
        >>> shape0 = (3, 4)
        >>> input0 = COOTensor(indices0, values0, shape0)
        >>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int64)
        >>> values1 = Tensor([3, 4], dtype=mstype.int32)
        >>> shape1 = (3, 4)
        >>> input1 = COOTensor(indices1, values1, shape1)
        >>> concat_dim = 1
        >>> out = ops.coo_concat((input0, input1), concat_dim)
        >>> print(out)
        COOTensor(shape=[3, 8], dtype=Int32, indices=Tensor(shape=[4, 2], dtype=Int64, value=
        [[0 1]
         [0 4]
         [1 2]
         [1 5]]), values=Tensor(shape=[4], dtype=Int32, value=[1 3 2 4]))
    """
    coo_concat_op = SparseConcat(concat_dim)
    in_indices, in_values, in_shapes = _set_coo_concat_input(sp_input)
    indices, values, _ = coo_concat_op(in_indices, in_values, in_shapes)
    out_shape = _calc_out_shape(sp_input, concat_dim)
    return COOTensor(indices, values, out_shape)


def coo_add(x1: COOTensor, x2: COOTensor, thresh: Tensor) -> COOTensor:
    """
    Computes the sum of x1 (COOTensor) and x2 (COOTensor), and returns a new COOTensor
    based on the computed result and `thresh`.

    Args:
        x1 (COOTensor): the first COOTensor to sum.
        x2 (COOTensor): the second COOTensor to sum.
        thresh (Tensor): A 0-D Tensor, represents the magnitude threshold that determines
            whether an output value/index pair takes place. Its dtype
            should match that of the values if they are real. If an output
            value is less than `thresh`, it will vanish.

    Returns:
        A COOTensor, the result of the sum.

    Raises:
        ValueError: If the dim of either input's (x1/x2) indices is not equal to 2.
        ValueError: If the dim of either input's (x1/x2) values is not equal to 1.
        ValueError: If the dim of either input's (x1/x2) shape is not equal to 1.
        ValueError: If the dim of thresh is not equal to 0.
        TypeError: If the dtype of either input's (x1/x2) indices is not int64.
        TypeError: If the dtype of either input's (x1/x2) shape is not int64.
        ValueError: If the length of either input's (x1/x2) indices is not equal to
            the length of its values.
        TypeError: If the dtype of either input's (x1/x2) values is not one of
            (int8/int16/int32/int64/float32/float64/complex64/complex128).
        TypeError: If the dtype of thresh is not one of
            (int8/int16/int32/int64/float32/float64).
        TypeError: If the dtype of x1's indices is not equal to the dtype of x2's indices.
        TypeError: If the dtype of x1's values is not equal to the dtype of x2's values.
        TypeError: If the dtype of x1's shape is not equal to the dtype of x2's shape.
        TypeError: If the dtype of (x1/x2)'s values does not match the dtype of thresh.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor, COOTensor
        >>> from mindspore import dtype as mstype
        >>> from mindspore import context
        >>> from mindspore import ops
        >>> indices0 = Tensor([[0, 1], [1, 2]], dtype=mstype.int64)
        >>> values0 = Tensor([1, 2], dtype=mstype.int32)
        >>> shape0 = (3, 4)
        >>> input0 = COOTensor(indices0, values0, shape0)
        >>> indices1 = Tensor([[0, 0], [1, 1]], dtype=mstype.int64)
        >>> values1 = Tensor([3, 4], dtype=mstype.int32)
        >>> shape1 = (3, 4)
        >>> input1 = COOTensor(indices1, values1, shape1)
        >>> thres = Tensor(0, dtype=mstype.int32)
        >>> out = ops.coo_add(input0, input1, thres)
        >>> print(out)
        COOTensor(shape=[3, 4], dtype=Int32, indices=Tensor(shape=[4, 2], dtype=Int64, value=
        [[0 0]
         [0 1]
         [1 1]
         [1 2]]), values=Tensor(shape=[4], dtype=Int32, value=[3 1 4 2]))
    """
    x1, x2 = promote_coo(x1, x2)
    thresh = thresh.astype(x1.dtype)
    x1_indices = x1.indices
    x1_values = x1.values
    x2_indices = x2.indices
    x2_values = x2.values
    den_shp = make_tensor(x1.shape, x1_indices.dtype)
    add_op = SparseAdd()
    indices, values, _ = add_op(x1_indices, x1_values, den_shp, x2_indices, x2_values, den_shp, thresh)
    return COOTensor(indices, values, x1.shape)


def csr_softmax(logits: CSRTensor, dtype: mstype):
    """
    Calculates the softmax of a sparse matrix in CSR format (CSRTensor).

    Args:
        logits (CSRTensor): Input sparse CSRTensor.
        dtype (dtype): Input data type.

    Returns:
        CSRTensor, a CSRTensor containing the following parts.

        - **indptr** - Indicates the start and end point for non-zero values in each row.
        - **indices** - The column positions of all non-zero values of the input.
        - **values** - The non-zero values of the dense tensor.
        - **shape** - The shape of the CSRTensor.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import ops
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import Tensor, CSRTensor
        >>> logits_indptr = Tensor([0, 4, 6], dtype=mstype.int32)
        >>> logits_indices = Tensor([0, 2, 3, 4, 3, 4], dtype=mstype.int32)
        >>> logits_values = Tensor([1, 2, 3, 4, 1, 2], dtype=mstype.float32)
        >>> shape = (2, 6)
        >>> logits = CSRTensor(logits_indptr, logits_indices, logits_values, shape)
        >>> out = ops.csr_softmax(logits, dtype=mstype.float32)
        >>> print(out)
        CSRTensor(shape=[2, 6], dtype=Float32, indptr=Tensor(shape=[3], dtype=Int32, value=[0 4 6]),
                       indices=Tensor(shape=[6], dtype=Int32, value=[0 2 3 4 3 4]),
                       values=Tensor(shape=[6], dtype=Float32, value=[ 3.20586003e-02  8.71443152e-02  2.36882806e-01
                       6.43914223e-01  2.68941432e-01  7.31058598e-01]))
    """
    if not isinstance(logits, CSRTensor):
        raise_type_error("For functional operator csr_softmax, logits must be type of CSRTensor.")
    sparse_matrix_softmax_op = SparseMatrixSoftmax(dtype)
    logits_batch_pointers = make_tensor([0, logits.values.shape[0]], mstype.int32)
    logits_shape = make_tensor(logits.shape, mstype.int32)
    shape, _, indptr, indices, values = sparse_matrix_softmax_op(logits_shape, logits_batch_pointers, logits.indptr,
                                                                 logits.indices, logits.values)
    output_shape = tuple(shape.asnumpy().tolist())
    return CSRTensor(indptr=indptr, indices=indices, values=values, shape=output_shape)


def csr_add(a: CSRTensor, b: CSRTensor, alpha: Tensor, beta: Tensor) -> CSRTensor:
    """
    Computes the linear combination of two input CSRTensors a and b.

    .. math::

        out = alpha * a + beta * b

    where both :math:`a` and :math:`b` are CSRTensor, and :math:`alpha` and :math:`beta` are both Tensor.

    Note:
        The user needs to ensure that the input sparse matrices are valid.
        Otherwise, the behavior of the operator is undefined. For example, when
        there are multiple elements at the same position, the operator may fail
        to execute or return an error.

    Args:
        a (CSRTensor): Input sparse CSRTensor.
        b (CSRTensor): Input sparse CSRTensor.
        alpha (Tensor): Dense Tensor, its shape must be able to broadcast to a.
        beta (Tensor): Dense Tensor, its shape must be able to broadcast to b.

    Returns:
        CSRTensor, a CSRTensor containing the following parts.

        - **indptr** - Indicates the start and end point for non-zero values in each row.
        - **indices** - The column positions of all non-zero values of the input.
        - **values** - The non-zero values of the dense tensor.
        - **shape** - The shape of the CSRTensor.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import Tensor, CSRTensor
        >>> from mindspore import ops
        >>> a_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> a_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> a_values = Tensor([1, 2], dtype=mstype.float32)
        >>> shape = (2, 6)
        >>> b_indptr = Tensor([0, 1, 2], dtype=mstype.int32)
        >>> b_indices = Tensor([0, 1], dtype=mstype.int32)
        >>> b_values = Tensor([1, 2], dtype=mstype.float32)
        >>> alpha = Tensor(1, mstype.float32)
        >>> beta = Tensor(1, mstype.float32)
        >>> csra = CSRTensor(a_indptr, a_indices, a_values, shape)
        >>> csrb = CSRTensor(b_indptr, b_indices, b_values, shape)
        >>> out = ops.csr_add(csra, csrb, alpha, beta)
        >>> print(out)
        CSRTensor(shape=[2, 6], dtype=Float32, \
                  indptr=Tensor(shape=[3], dtype=Int32, value=[0 1 2]), \
                  indices=Tensor(shape=[2], dtype=Int32, value=[0 1]), \
                  values=Tensor(shape=[2], dtype=Float32, value=[ 2.00000000e+00  4.00000000e+00]))
    """
    if not isinstance(a, CSRTensor) or not isinstance(b, CSRTensor):
        raise_type_error("For functional operator csr_add, both inputs a and b must be type of CSRTensor.")
    if not isinstance(alpha, Tensor) or not isinstance(beta, Tensor):
        raise_type_error("For functional operator csr_add, both inputs alpha and beta must be Tensor.")
    csr_add_op = SparseMatrixAdd()
    a, b = promote_csr(a, b)
    alpha = alpha.astype(a.dtype)
    beta = beta.astype(a.dtype)
    a_batch_pointers = ops.concat((make_tensor([0]), ops.TensorShape()(a.values))).astype(a.indptr.dtype)
    b_batch_pointers = ops.concat((make_tensor([0]), ops.TensorShape()(b.values))).astype(b.indptr.dtype)
    a_shape = make_tensor(a.shape, a.indptr.dtype)
    b_shape = make_tensor(b.shape, b.indptr.dtype)
    _, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
                                               b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
                                               alpha, beta)
    return CSRTensor(indptr, indices, values, a.shape)


__all__ = [
    'coalesce',
    'coo2csr',
    'coo_tensor_get_dense_shape',
    'coo_tensor_get_indices',
    'coo_tensor_get_values',
    'csr_div',
    'csr_gather',
    'csr_mm',
    'csr_mul',
    'csr_mv',
    'csr_reduce_sum',
    'csr_to_coo',
    'csr2coo',
    'csr_tensor_get_dense_shape',
    'csr_tensor_get_indices',
    'csr_tensor_get_indptr',
    'csr_tensor_get_values',
    'dense_to_sparse_coo',
    'dense_to_sparse_csr',
    'make_sparse_tensor',
    'make_coo_tensor',
    'make_csr_tensor',
    'make_row_tensor',
    'make_row_tensor_inner',
    'make_map_parameter',
    'row_tensor_get_values',
    'row_tensor_get_indices',
    'row_tensor_get_dense_shape',
    'row_tensor_add',
    'coo_add',
    'coo_concat',
    'csr_add',
    'csr_softmax',
    'csr_to_dense'
]

__all__.sort()