• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""smoke tests for CSR operations"""
16
17import os
18import pytest
19import numpy as np
20
21from mindspore import Tensor, CSRTensor, jit, nn, ops
22from mindspore.common import dtype as mstype
23from mindspore.train.serialization import export, load
24from mindspore.ops import functional as F
25from mindspore.ops.operations import _csr_ops
26
27from .sparse_utils import get_platform, compare_res, compare_csr
28
29
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_make_csr():
    """
    Feature: CSRTensor construction in Graph and PyNative.
    Description: Build CSRTensor(indptr, indices, values, shape) eagerly and under jit, then
        copy-construct a third tensor through CSRTensor(csr_tensor=...).
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    row_ptr = Tensor([0, 1, 2])
    col_idx = Tensor([0, 1])
    data = Tensor([1, 2], dtype=mstype.float32)
    dense_shape = (2, 6)

    def build_csr():
        return CSRTensor(row_ptr, col_idx, data, dense_shape)

    build_csr_graph = jit(build_csr)

    # The eager result is the reference for the graph-compiled result.
    eager_csr = build_csr()
    graph_csr = build_csr_graph()
    compare_csr(eager_csr, graph_csr)

    # Copy-construction must reproduce the source tensor.
    copied_csr = CSRTensor(csr_tensor=graph_csr)
    compare_csr(copied_csr, graph_csr)
58
59
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_attr():
    """
    Feature: CSRTensor attribute access in Graph and PyNative.
    Description: Read CSRTensor.indptr, .indices, .values and .shape, and call astype() and to_tuple(),
        comparing the eager results with the jit-compiled ones.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    row_ptr = Tensor([0, 1, 2])
    col_idx = Tensor([0, 1])
    data = Tensor([1, 2], dtype=mstype.float32)
    source_csr = CSRTensor(row_ptr, col_idx, data, (2, 6))

    def read_structure():
        return source_csr.indptr, source_csr.indices

    def read_payload():
        return source_csr.values, source_csr.shape

    def cast_to_int32():
        return source_csr.astype(mstype.int32)

    def unpack():
        return source_csr.to_tuple()

    graph_read_structure = jit(read_structure)
    graph_read_payload = jit(read_payload)
    graph_cast_to_int32 = jit(cast_to_int32)
    graph_unpack = jit(unpack)

    # Eager (PyNative) results.
    eager_indptr, eager_indices = read_structure()
    eager_values, eager_shape = read_payload()
    eager_cast = cast_to_int32()
    eager_parts = unpack()

    # Graph-mode results.
    graph_indptr, graph_indices = graph_read_structure()
    graph_values, graph_shape = graph_read_payload()
    graph_cast = graph_cast_to_int32()
    graph_parts = graph_unpack()

    # Attributes: rebuild a CSRTensor from each mode's attributes and compare.
    rebuilt_eager = CSRTensor(eager_indptr, eager_indices, eager_values, eager_shape)
    rebuilt_graph = CSRTensor(graph_indptr, graph_indices, graph_values, graph_shape)
    compare_csr(rebuilt_eager, rebuilt_graph)

    # astype
    compare_csr(eager_cast, graph_cast)

    # to_tuple: tensors are compared element-wise, everything else by equality.
    assert len(eager_parts) == len(graph_parts)
    for eager_part, graph_part in zip(eager_parts, graph_parts):
        if isinstance(eager_part, Tensor):
            assert (eager_part.asnumpy() == graph_part.asnumpy()).all()
        else:
            assert eager_part == graph_part
120
121
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_csr_tensor_in_while():
    """
    Feature: Test CSRTensor in while loop.
    Description: Test CSRTensor computation in while loop, then export the net to MindIR, reload it and
        check that the reloaded graph produces the same CSRTensor.
    Expectation: Success.
    """
    class CSRTensorValuesDouble(nn.Cell):
        """Return a CSRTensor with the same structure as the input and values multiplied by 2."""

        def construct(self, x):
            indptr = x.indptr
            indices = x.indices
            values = x.values * 2
            shape = x.shape
            return CSRTensor(indptr, indices, values, shape)

    class CSRTensorValuesAdd2(nn.Cell):
        """Return a CSRTensor with the same structure as the input and values increased by 2."""

        def construct(self, x):
            indptr = x.indptr
            indices = x.indices
            values = x.values + 2
            shape = x.shape
            return CSRTensor(indptr, indices, values, shape)

    class CSRTensorWithControlWhile(nn.Cell):
        """Add 2 to the values once, then double them once per iteration while a > b."""

        def __init__(self, shape):
            super().__init__()
            self.op1 = CSRTensorValuesDouble()
            self.op2 = CSRTensorValuesAdd2()
            self.shape = shape

        @jit
        def construct(self, a, b, indptr, indices, values):
            x = CSRTensor(indptr, indices, values, self.shape)
            x = self.op2(x)
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x

    a = Tensor(3, mstype.int32)
    b = Tensor(0, mstype.int32)
    indptr = Tensor([0, 1, 2])
    indices = Tensor([0, 1])
    values = Tensor([1, 2], dtype=mstype.float32)
    shape = (2, 6)
    net = CSRTensorWithControlWhile(shape)
    out = net(a, b, indptr, indices, values)
    assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
    assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
    # With a=3 and b=0 the loop body runs three times: (values + 2) * 2**3.
    assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
    assert shape == out.shape

    # Test Export MindIR
    file_name = "csrtensor_with_control_while_net"
    mindir_name = file_name + ".mindir"
    # Remove any leftover from an earlier run so the existence check below
    # really proves that this export wrote the file.
    if os.path.exists(mindir_name):
        os.remove(mindir_name)
    try:
        export(net, a, b, indptr, indices, values, file_name=file_name, file_format="MINDIR")
        assert os.path.exists(mindir_name)

        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(a, b, indptr, indices, values)
        assert np.allclose(out.indptr.asnumpy(), outputs_after_load.indptr.asnumpy())
        assert np.allclose(out.indices.asnumpy(), outputs_after_load.indices.asnumpy())
        assert np.allclose(out.values.asnumpy(), outputs_after_load.values.asnumpy())
        assert out.shape == outputs_after_load.shape
    finally:
        # Do not leave the generated artifact in the working directory,
        # whether the assertions passed or not.
        if os.path.exists(mindir_name):
            os.remove(mindir_name)
192
193
@pytest.mark.level2
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_tensor_in_while_cpu():
    """
    Feature: CSRTensor inside a while loop.
    Description: Rebuild a CSRTensor on every iteration of a jit-compiled while loop and check the
        resulting indptr, indices, values and shape.
    Expectation: Success.
    """
    class CSRTensorValuesDouble(nn.Cell):
        """Same structure as the input, values multiplied by 2."""

        def construct(self, x):
            doubled = x.values * 2
            return CSRTensor(x.indptr, x.indices, doubled, x.shape)

    class CSRTensorValuesAdd2(nn.Cell):
        """Same structure as the input, values increased by 2."""

        def construct(self, x):
            shifted = x.values + 2
            return CSRTensor(x.indptr, x.indices, shifted, x.shape)

    class CSRTensorWithControlWhile(nn.Cell):
        """Add 2 once, then double the values once per iteration while a > b."""

        def __init__(self, shape):
            super().__init__()
            self.op1 = CSRTensorValuesDouble()
            self.op2 = CSRTensorValuesAdd2()
            self.shape = shape

        @jit
        def construct(self, a, b, indptr, indices, values):
            x = self.op2(CSRTensor(indptr, indices, values, self.shape))
            while a > b:
                x = self.op1(x)
                b = b + 1
            return x

    upper = Tensor(3, mstype.int32)
    counter = Tensor(0, mstype.int32)
    row_ptr = Tensor([0, 1, 2])
    col_idx = Tensor([0, 1])
    data = Tensor([1, 2], dtype=mstype.float32)
    dense_shape = (2, 6)

    result = CSRTensorWithControlWhile(dense_shape)(upper, counter, row_ptr, col_idx, data)

    # Three loop iterations (3 > 0, 1, 2) double the shifted values three times.
    expected_values = (data.asnumpy() + 2) * 8
    assert np.allclose(result.indptr.asnumpy(), row_ptr.asnumpy(), .0, .0)
    assert np.allclose(result.indices.asnumpy(), col_idx.asnumpy(), .0, .0)
    assert np.allclose(expected_values, result.values.asnumpy(), .0, .0)
    assert dense_shape == result.shape
248
249
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_batch_csr_ops():
    """
    Feature: Batch CSR-related Ops.
    Description: Run csr_reduce_sum, CSR * dense, CSR / dense and CSRGather on a batched CSRTensor
        in both Graph and PyNative.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    gather_op = _csr_ops.CSRGather()

    row_ptr = Tensor([0, 1, 1, 2, 2], dtype=mstype.int32)
    col_idx = Tensor([0, 1], dtype=mstype.int32)
    data = Tensor([[2, 1, 3], [2, 1, 3]], dtype=mstype.float32)
    batch_shape = (4, 2, 3)
    dense = Tensor(
        [[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]],
        dtype=mstype.float32)
    batched_csr = CSRTensor(row_ptr, col_idx, data, batch_shape)

    def run_gather():
        return gather_op(row_ptr, col_idx, dense, batch_shape)

    def run_reduce_sum():
        return F.csr_reduce_sum(batched_csr, 1)

    def run_elemwise():
        return batched_csr * dense, batched_csr / dense

    # TODO(PyTrace): PyTrace Async bug.
    # Graph mode is deliberately executed before PyNative here.
    graph_reduced = jit(run_reduce_sum)()
    eager_reduced = run_reduce_sum()
    reduced_row_0 = np.array([[2., 1., 3.]], dtype=np.float32)
    reduced_row_2 = np.array([[2., 1., 3.]], dtype=np.float32)
    for reduced in (eager_reduced, graph_reduced):
        assert np.allclose(reduced[0].asnumpy(), reduced_row_0)
        assert np.allclose(reduced[2].asnumpy(), reduced_row_2)

    # TODO(PyTrace): PyTrace Async bug.
    graph_elemwise = jit(run_elemwise)()
    eager_elemwise = run_elemwise()
    product_values = np.array([[2., 1., 3.], [4., 2., 6.]], dtype=np.float32)
    quotient_values = np.array([[2., 1., 3.], [1., 0.5, 1.5]], dtype=np.float32)
    for product, quotient in (eager_elemwise, graph_elemwise):
        assert np.allclose(product.values.asnumpy(), product_values)
        assert np.allclose(quotient.values.asnumpy(), quotient_values)

    gathered_expect = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32)
    eager_gathered = run_gather()
    graph_gathered = jit(run_gather)()
    assert np.allclose(eager_gathered.asnumpy(), gathered_expect)
    assert np.allclose(graph_gathered.asnumpy(), gathered_expect)
314
315
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_ops():
    """
    Feature: CSR-related Ops.
    Description: Run csr_reduce_sum, csr_mv, CSRTensor.mm and the CSR/dense multiply and divide
        operators in both Graph and PyNative.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    row_ptr = Tensor([0, 1, 2], dtype=mstype.int32)
    col_idx = Tensor([0, 1], dtype=mstype.int32)
    data = Tensor([2, 1], dtype=mstype.float32)
    sparse = CSRTensor(row_ptr, col_idx, data, (2, 4))

    ones_matrix = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
    ones_column = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)
    rhs_matrix = Tensor([[1., 2.], [1, 2.], [1, 2.], [1., 2.]], dtype=mstype.float32)

    def dense_results():
        reduced = F.csr_reduce_sum(sparse, 1)
        mat_vec = F.csr_mv(sparse, ones_column)
        mat_mat = sparse.mm(rhs_matrix)
        return reduced, mat_vec, mat_mat

    def sparse_results():
        csr_times_dense = sparse * ones_matrix
        dense_times_csr = ones_matrix * sparse
        csr_over_dense = sparse / ones_matrix
        return csr_times_dense, dense_times_csr, csr_over_dense

    dense_results_graph = jit(dense_results)
    sparse_results_graph = jit(sparse_results)

    # TODO(PyTrace): PyTrace async bug.
    # Graph mode is deliberately executed before PyNative here.
    graph_dense = dense_results_graph()
    eager_dense = dense_results()
    dense_expectations = (
        np.array([[2.], [1.]], dtype=np.float32),
        np.array([[2.], [1.]], dtype=np.float32),
        np.array([[2., 4.], [1., 2.]], dtype=np.float32),
    )
    for outputs in (eager_dense, graph_dense):
        for position, expected in enumerate(dense_expectations):
            assert np.allclose(outputs[position].asnumpy(), expected)

    # TODO(PyTrace): PyTrace async bug.
    graph_sparse = sparse_results_graph()
    eager_sparse = sparse_results()
    sparse_values_expect = np.array([2., 1.], dtype=np.float32)
    for outputs in (eager_sparse, graph_sparse):
        for position in range(3):
            assert np.allclose(outputs[position].values.asnumpy(), sparse_values_expect)
377
378
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csrtensor_export_and_import_mindir():
    """
    Feature: Test exporting and loading CSRTensor MindIR.
    Description: Export a net that builds a CSRTensor to MindIR, load it back and check that the loaded
        graph returns the same CSRTensor as the original net.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    class TestCSRTensor(nn.Cell):
        """Build a CSRTensor with a fixed shape from indptr, indices and values."""

        def __init__(self, shape):
            super().__init__()
            self.shape = shape

        def construct(self, indptr, indices, values):
            return CSRTensor(indptr, indices, values, self.shape)

    indptr = Tensor([0, 1, 2])
    indices = Tensor([0, 1])
    values = Tensor([2, 1], dtype=mstype.float32)
    shape = (2, 4)
    net = TestCSRTensor(shape)

    file_name = "csrtensor_net"
    mindir_name = file_name + ".mindir"
    # Remove any leftover from an earlier run so the existence check below
    # really proves that this export wrote the file.
    if os.path.exists(mindir_name):
        os.remove(mindir_name)
    try:
        export(net, indptr, indices, values, file_name=file_name, file_format="MINDIR")
        assert os.path.exists(mindir_name)

        out = net(indptr, indices, values)
        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(indptr, indices, values)
        assert np.allclose(out.indptr.asnumpy(), outputs_after_load.indptr.asnumpy())
        assert np.allclose(out.indices.asnumpy(), outputs_after_load.indices.asnumpy())
        assert np.allclose(out.values.asnumpy(), outputs_after_load.values.asnumpy())
        assert out.shape == outputs_after_load.shape
    finally:
        # Do not leave the generated artifact in the working directory,
        # whether the assertions passed or not.
        if os.path.exists(mindir_name):
            os.remove(mindir_name)
421
422
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_csrops_export_and_import_mindir():
    """
    Feature: Test exporting and loading CSRTensor MindIR in a net.
    Description: Export a net that runs csr_reduce_sum, csr_mv and CSR/dense multiplication to MindIR,
        load it back and check that the loaded graph matches the original net's outputs.
    Expectation: Success.
    """
    class TestCSRNet(nn.Cell):
        """Build a CSRTensor and return three dense and two sparse results computed from it."""

        def __init__(self, shape):
            super().__init__()
            self.shape = shape

        def construct(self, indptr, indices, values, dense_tensor, dense_vector):
            csr_tensor = CSRTensor(indptr, indices, values, self.shape)
            dense1 = F.csr_reduce_sum(csr_tensor, 1)
            dense2 = F.csr_mv(csr_tensor, dense_vector)
            dense3 = dense1 * dense2
            sparse1 = csr_tensor * dense_tensor
            sparse2 = dense_tensor * csr_tensor
            return dense1, dense2, dense3, sparse1, sparse2

    indptr = Tensor([0, 1, 2], dtype=mstype.int32)
    indices = Tensor([0, 1], dtype=mstype.int32)
    values = Tensor([2, 1], dtype=mstype.float32)
    shape = (2, 4)
    dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
    dense_vector = Tensor([[1.], [1], [1], [1]], dtype=mstype.float32)

    net = TestCSRNet(shape)
    file_name = "csrops_net"
    mindir_name = file_name + ".mindir"
    # Remove any leftover from an earlier run so the existence check below
    # really proves that this export wrote the file.
    if os.path.exists(mindir_name):
        os.remove(mindir_name)
    try:
        export(net, indptr, indices, values, dense_tensor, dense_vector, file_name=file_name, file_format="MINDIR")
        assert os.path.exists(mindir_name)

        out = net(indptr, indices, values, dense_tensor, dense_vector)
        expect0 = np.array([[2.], [1.]], dtype=np.float32)
        expect1 = np.array([[2.], [1.]], dtype=np.float32)
        expect2 = np.array([[4.], [1.]], dtype=np.float32)
        expect3 = np.array([2., 1.], dtype=np.float32)
        assert np.allclose(out[0].asnumpy(), expect0)
        assert np.allclose(out[1].asnumpy(), expect1)
        assert np.allclose(out[2].asnumpy(), expect2)
        assert np.allclose(out[3].values.asnumpy(), expect3)
        assert np.allclose(out[4].values.asnumpy(), expect3)

        graph = load(mindir_name)
        loaded_net = nn.GraphCell(graph)
        outputs_after_load = loaded_net(indptr, indices, values, dense_tensor, dense_vector)
        assert np.allclose(out[0].asnumpy(), outputs_after_load[0].asnumpy())
        assert np.allclose(out[1].asnumpy(), outputs_after_load[1].asnumpy())
        assert np.allclose(out[2].asnumpy(), outputs_after_load[2].asnumpy())
        assert np.allclose(out[3].values.asnumpy(), outputs_after_load[3].values.asnumpy())
        assert np.allclose(out[4].values.asnumpy(), outputs_after_load[4].values.asnumpy())
        assert out[3].shape == outputs_after_load[3].shape
        assert out[4].shape == outputs_after_load[4].shape
    finally:
        # Do not leave the generated artifact in the working directory,
        # whether the assertions passed or not.
        if os.path.exists(mindir_name):
            os.remove(mindir_name)
480
481
@pytest.mark.level2
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isinstance_csr_tensor():
    """
    Feature: isinstance with CSRTensor.
    Description: Evaluate isinstance() on a CSRTensor and on a plain Tensor against single types and
        tuples of types, in both Graph and PyNative.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    row_ptr = Tensor([0, 1, 2])
    col_idx = Tensor([0, 1])
    data = Tensor([2, 1], dtype=mstype.float32)
    dense_shape = (2, 4)

    def collect_isinstance_flags():
        sparse = CSRTensor(row_ptr, col_idx, data, dense_shape)
        # Checks on a CSRTensor input.
        csr_is_tensor = isinstance(sparse, Tensor)
        csr_is_bool = isinstance(sparse, bool)
        csr_is_float = isinstance(sparse, float)
        csr_in_mixed_tuple = isinstance(sparse, (Tensor, CSRTensor, int, float))
        csr_is_csr = isinstance(sparse, CSRTensor)

        # Checks on a plain Tensor input.
        tensor_is_csr = isinstance(row_ptr, CSRTensor)
        tensor_in_tuple = isinstance(row_ptr, (Tensor, CSRTensor))
        return (csr_is_tensor, csr_is_bool, csr_is_float, csr_in_mixed_tuple, csr_is_csr,
                tensor_is_csr, tensor_in_tuple)

    collect_isinstance_flags_graph = jit(collect_isinstance_flags)

    expected_flags = (False, False, False, True, True, False, True)
    eager_flags = collect_isinstance_flags()
    graph_flags = collect_isinstance_flags_graph()
    assert eager_flags == expected_flags
    assert graph_flags == expected_flags
520
521
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dtype_csr_tensor():
    """
    Feature: F.dtype with CSRTensor.
    Description: Query the dtype of a float32 CSRTensor through F.dtype(x) and x.dtype, in both Graph
        and PyNative.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    row_ptr = Tensor([0, 1, 2])
    col_idx = Tensor([0, 1])
    data = Tensor([2, 1], dtype=mstype.float32)
    dense_shape = (2, 4)

    def query_dtypes():
        sparse = CSRTensor(row_ptr, col_idx, data, dense_shape)
        return F.dtype(sparse), sparse.dtype

    query_dtypes_graph = jit(query_dtypes)

    eager_functional, eager_attribute = query_dtypes()
    graph_functional, graph_attribute = query_dtypes_graph()
    for observed in (eager_functional, eager_attribute, graph_functional, graph_attribute):
        assert observed in [mstype.float32]
550
551
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bprop():
    """
    Feature: Back-propagation with CSR-related Ops.
    Description: Take gradients through csr_mv, csr_reduce_sum, CSR * dense and CSR / dense and compare
        the gradients of the values (and of the dense operand) with precomputed arrays.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return
    grad_all = ops.GradOperation(get_all=True)

    @grad_all
    @jit
    def grad_csr_mul(row_ptr, col_idx, data, dims, dense):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse * dense

    @grad_all
    @jit
    def grad_csr_div(row_ptr, col_idx, data, dims, dense):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse / dense

    @grad_all
    @jit
    def grad_csr_reduce_sum(row_ptr, col_idx, data, dims, axis):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return F.csr_reduce_sum(sparse, axis)

    @grad_all
    @jit
    def grad_csr_mv(row_ptr, col_idx, data, dims, dense):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return F.csr_mv(sparse, dense)

    def check_grads(grads, expected_len, values_grad, dense_grad=None):
        # grads[2] is the gradient w.r.t. values; grads[3], when checked, w.r.t. the dense operand.
        assert len(grads) == expected_len
        assert np.allclose(grads[2].asnumpy(), values_grad)
        if dense_grad is not None:
            assert np.allclose(grads[3].asnumpy(), dense_grad)

    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
    values = Tensor(np.arange(6), dtype=mstype.float32)
    dense_shape = (3, 4)

    # csr_mv: outputs are indptr, indices, values, dense_grad
    mv_operand = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
    mv_grads = grad_csr_mv(indptr, indices, values, dense_shape, mv_operand)
    check_grads(mv_grads, 4,
                np.array([4, 1, 2, 3, 2, 4], dtype=np.float32),
                np.array([[1], [6], [3], [5]], dtype=np.float32))

    # csr_reduce_sum along axis 1
    reduce_grads = grad_csr_reduce_sum(indptr, indices, values, dense_shape, 1)
    check_grads(reduce_grads, 3, np.ones(6, dtype=np.float32))

    # CSR * column vector
    mul_column = Tensor([[1], [2], [3]], dtype=mstype.float32)
    mul_column_grads = grad_csr_mul(indptr, indices, values, dense_shape, mul_column)
    check_grads(mul_column_grads, 4,
                np.array([1, 2, 2, 2, 3, 3], dtype=np.float32),
                np.array([[0], [6], [9]], dtype=np.float32))

    # CSR * full matrix
    mul_matrix = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
    mul_matrix_grads = grad_csr_mul(indptr, indices, values, dense_shape, mul_matrix)
    check_grads(mul_matrix_grads, 4,
                np.array([3, 4, 5, 6, 9, 11], dtype=np.float32),
                np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32))

    # CSR / column vector
    div_column = Tensor([[1], [2], [3]], dtype=mstype.float32)
    div_column_grads = grad_csr_div(indptr, indices, values, dense_shape, div_column)
    check_grads(div_column_grads, 4,
                np.array([1, 0.5, 0.5, 0.5, 0.3333333, 0.3333333], dtype=np.float32),
                np.array([[0], [-1.5], [-1]], dtype=np.float32))

    # CSR / full matrix
    div_matrix = Tensor(np.arange(1, 13).reshape(3, 4), dtype=mstype.float32)
    div_matrix_grads = grad_csr_div(indptr, indices, values, dense_shape, div_matrix)
    check_grads(div_matrix_grads, 4,
                np.array([0.25, 0.2, 0.16666667, 0.14285715, 0.1, 0.0833333], dtype=np.float32),
                np.array(
                    [[0, 0, 0, 0], [-0.04, -0.05555556, -0.06122449, 0], [0, -0.04, 0, -0.03472222]],
                    dtype=np.float32))
641
642
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_method():
    """
    Feature: CSRTensor conversion methods.
    Description: Call csr_tensor.to_coo() and csr_tensor.to_dense() inside nn.Cell nets and compare
        the outputs with precomputed arrays.
    Expectation: Success.
    """
    if get_platform() != "linux":
        return

    class CSRToCOONet(nn.Cell):
        """Convert the input CSRTensor with to_coo()."""

        def construct(self, csr_tensor):
            return csr_tensor.to_coo()

    class CSRToDenseNet(nn.Cell):
        """Convert the input CSRTensor with to_dense()."""

        def construct(self, csr_tensor):
            return csr_tensor.to_dense()

    sparse = CSRTensor(
        Tensor([0, 1, 4, 6], dtype=mstype.int32),
        Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32),
        Tensor(np.arange(6), dtype=mstype.float32),
        (3, 4))

    # CSR -> COO: one (row, col) pair per stored value.
    coo = CSRToCOONet()(sparse)
    coo_indices_expect = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32)
    coo_values_expect = np.arange(6).astype(np.float32)
    assert np.allclose(coo.indices.asnumpy(), coo_indices_expect)
    assert np.allclose(coo.values.asnumpy(), coo_values_expect)

    # CSR -> dense
    dense = CSRToDenseNet()(sparse)
    dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
    assert np.allclose(dense.asnumpy(), dense_expect)
679
680
@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_bprop2():
    """
    Feature: Back-propagation with CSR-related Ops.
    Description: Take gradients through CSRTensor construction, attribute access (indptr, indices,
        values, shape, dtype) and methods (astype, to_tuple, abs, to_coo, to_dense).
    Expectation: Success.
    """
    if get_platform() != "linux":
        return
    grad_all = ops.GradOperation(get_all=True)
    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
    values = Tensor(np.arange(6) - 3.5, dtype=mstype.float32)
    dense_shape = (3, 4)

    @grad_all
    @jit
    def grad_make_csr(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse

    @grad_all
    @jit
    def grad_indptr(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.indptr

    @grad_all
    @jit
    def grad_indices(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.indices

    @grad_all
    @jit
    def grad_values(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.values

    @grad_all
    @jit
    def grad_shape(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.shape

    @grad_all
    @jit
    def grad_cast(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.astype(mstype.int32)

    @grad_all
    @jit
    def grad_dtype(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.dtype

    @grad_all
    @jit
    def grad_to_tuple(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.to_tuple()

    @grad_all
    @jit
    def grad_abs(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.abs()

    @grad_all
    @jit
    def grad_to_coo(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.to_coo()

    @grad_all
    @jit
    def grad_to_dense(row_ptr, col_idx, data, dims):
        sparse = CSRTensor(row_ptr, col_idx, data, dims)
        return sparse.to_dense()

    # Expected gradients for (indptr, indices, values).
    all_zero = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.zeros(values.shape, np.float32))
    values_on = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.ones(values.shape, np.float32))
    values_absgrad = (np.zeros(indptr.shape, np.int32), np.zeros(indices.shape, np.int32), np.sign(values.asnumpy()))

    # Each gradient function is paired with the gradients it should return.
    cases = (
        (grad_make_csr, values_on),
        (grad_indptr, all_zero),
        (grad_indices, all_zero),
        (grad_values, values_on),
        (grad_cast, values_on),
        (grad_shape, all_zero),
        (grad_dtype, all_zero),
        (grad_to_tuple, values_on),
        (grad_abs, values_absgrad),
        (grad_to_coo, values_on),
        (grad_to_dense, values_on),
    )
    for grad_fn, expected in cases:
        compare_res(grad_fn(indptr, indices, values, dense_shape), expected)
779
780
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dense_to_csr():
    """
    Feature: Dense tensor to CSR conversion.
    Description: Call tensor.to_csr() in Graph and PyNative, compare with a hand-built CSRTensor, and
        take the gradient using that CSRTensor as the sens input.
    Expectation: Success.
    """
    dense_input = Tensor([[0, 1, 2, 0], [0, 0, 0, 0], [1, 0, 0, 0]], dtype=mstype.float32)
    grad_with_sens = ops.GradOperation(get_all=True, sens_param=True)

    def convert(dense_tensor):
        return dense_tensor.to_csr()

    eager_csr = convert(dense_input)
    graph_csr = jit(convert)(dense_input)

    expected_csr = CSRTensor(
        Tensor([0, 2, 2, 3], dtype=mstype.int32),
        Tensor([1, 2, 0], dtype=mstype.int32),
        Tensor([1, 2, 1], dtype=mstype.float32),
        (3, 4))
    for converted in (eager_csr, graph_csr):
        assert isinstance(converted, CSRTensor)
    compare_csr(eager_csr, expected_csr)
    compare_csr(graph_csr, expected_csr)

    # The expected CSRTensor is passed as the sens (output gradient) argument.
    input_grads = grad_with_sens(convert)(dense_input, expected_csr)
    expected_grad = np.array([[0, 1, 2, 0], [0, 0, 0, 0], [1, 0, 0, 0]])
    assert (input_grads[0].asnumpy() == expected_grad).all()
808
809
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_magic_methods():
    """
    Feature: Operator overloads of CSRTensor.
    Description: Apply unary minus, binary plus and binary minus to CSRTensors, eagerly and through jit.
    Expectation: Every result equals the precomputed CSRTensor.
    """
    if get_platform() != "linux":
        return

    def negate(indptr, indices, values, dense_shape):
        operand = CSRTensor(indptr, indices, values, dense_shape)
        return -operand

    def add(lhs_indptr, rhs_indptr, lhs_indices, rhs_indices, lhs_values, rhs_values, dense_shape):
        lhs = CSRTensor(lhs_indptr, lhs_indices, lhs_values, dense_shape)
        rhs = CSRTensor(rhs_indptr, rhs_indices, rhs_values, dense_shape)
        return lhs + rhs

    def subtract(lhs_indptr, rhs_indptr, lhs_indices, rhs_indices, lhs_values, rhs_values, dense_shape):
        lhs = CSRTensor(lhs_indptr, lhs_indices, lhs_values, dense_shape)
        rhs = CSRTensor(rhs_indptr, rhs_indices, rhs_values, dense_shape)
        return lhs - rhs

    def check_eager_then_graph(fn, args, expected):
        # Run the plain Python function first, then its jit-compiled form,
        # comparing each result against the same expected CSRTensor.
        compare_csr(fn(*args), expected)
        compare_csr(jit(fn)(*args), expected)

    shape = (3, 4)

    # Left operand: 6 stored values.
    lhs_indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
    lhs_indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
    lhs_values = Tensor(np.arange(6) - 3.5, dtype=mstype.float32)

    # Right operand: 4 stored values, same dense shape.
    rhs_indptr = Tensor([0, 2, 3, 4], dtype=mstype.int32)
    rhs_indices = Tensor([2, 3, 0, 1], dtype=mstype.int32)
    rhs_values = Tensor(np.arange(4) - 2.5, dtype=mstype.float32)

    unary_args = (lhs_indptr, lhs_indices, lhs_values, shape)
    binary_args = (lhs_indptr, rhs_indptr, lhs_indices, rhs_indices, lhs_values, rhs_values, shape)

    check_eager_then_graph(
        negate,
        unary_args,
        CSRTensor(lhs_indptr, lhs_indices, Tensor([3.5, 2.5, 1.5, 0.5, -0.5, -1.5], mstype.float32), shape),
    )

    check_eager_then_graph(
        add,
        binary_args,
        CSRTensor(Tensor([0, 2, 5, 7], mstype.int32), Tensor([2, 3, 0, 1, 2, 1, 3], mstype.int32),
                  Tensor([-2.5, -5, -3, -1.5, -0.5, 1, 1.5], mstype.float32), shape),
    )

    check_eager_then_graph(
        subtract,
        binary_args,
        CSRTensor(Tensor([0, 2, 5, 7], mstype.int32), Tensor([2, 3, 0, 1, 2, 1, 3], mstype.int32),
                  Tensor([2.5, -2, -2, -1.5, -0.5, 0, 1.5], mstype.float32), shape),
    )
865
866
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_add_dynamic_shape_methods():
    """
    Feature: CSRTensor add (dynamic-shape scenario, per the test name) driven through an nn.Cell.
    Description: Evaluate -x + y + z with one CSRTensor bound to x, y and z, eagerly and through jit.
    Expectation: The result equals the input CSRTensor.
    """
    if get_platform() != "linux":
        return

    class NegAddNet(nn.Cell):
        def construct(self, x, y, z):
            negated = -x
            return negated + y + z

    def run_net(indptr, indices, values, dense_shape):
        # The Cell is created inside this function, so it is also created
        # inside the jit-compiled call below.
        csr = CSRTensor(indptr, indices, values, dense_shape)
        cell = NegAddNet()
        return cell(csr, csr, csr)

    shape = (4, 5)
    indptr = Tensor([0, 1, 2, 4, 5], dtype=mstype.int32)
    indices = Tensor([4, 4, 1, 2, 2], dtype=mstype.int32)
    values = Tensor(np.arange(5) - 2.5, dtype=mstype.float32)
    call_args = (indptr, indices, values, shape)

    # -x + x + x leaves the stored values unchanged.
    expected = CSRTensor(indptr, indices, Tensor([-2.5, -1.5, -0.5, 0.5, 1.5], mstype.float32), shape)

    eager_result = run_net(*call_args)
    compare_csr(eager_result, expected)

    graph_result = jit(run_net)(*call_args)
    compare_csr(graph_result, expected)
900