• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16@File  : test_sparse_pynative.py
17@Author:
18@Date  : 2020-08-04
19@Desc  : test mindspore sparse pynative
20"""
21import pytest
22import mindspore as ms
23import mindspore.nn as nn
24from mindspore import context, Tensor, RowTensor, SparseTensor
25from mindspore.ops import composite as C
26
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Switch the global MindSpore context for the duration of this module.

    Before the first test in the module runs, the context is put into
    PyNative mode with sparse support enabled. After the last test, it is
    reset to graph mode with sparse support disabled. Note that this is a
    fixed reset, not a restore of whatever context was active before.
    """
    pynative_sparse_on = {"mode": context.PYNATIVE_MODE, "enable_sparse": True}
    graph_sparse_off = {"mode": context.GRAPH_MODE, "enable_sparse": False}

    context.set_context(**pynative_sparse_on)
    yield
    context.set_context(**graph_sparse_off)
32
33
# Gradient operator shared by the tests below; `get_all=True` asks for the
# gradients with respect to every positional input of the wrapped network.
grad_all = C.GradOperation(get_all=True)


class GradWrap(nn.Cell):
    """Cell that returns the gradients of `network` w.r.t. all of its inputs."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, *args):
        # Build the gradient function for the wrapped network, then apply it
        # to the same positional inputs the network itself would receive.
        return grad_all(self.network)(*args)
42
43
def test_row_tensor_attr():
    """Smoke-test reading RowTensor attributes inside a Cell, forward and grad.

    The test only checks that both calls run without raising; the returned
    values are not inspected.
    """
    class RowTensorGetAttr(nn.Cell):
        def __init__(self, dense_shape):
            super().__init__()
            self.dense_shape = dense_shape

        def construct(self, indices, values):
            row_tensor = RowTensor(indices, values, self.dense_shape)
            return row_tensor.values, row_tensor.indices, row_tensor.dense_shape

    shape = (3, 2)
    row_indices = Tensor([0])
    row_values = Tensor([[1, 2]], dtype=ms.float32)

    # Plain forward call.
    RowTensorGetAttr(shape)(row_indices, row_values)
    # Same network wrapped to take gradients w.r.t. all inputs.
    GradWrap(RowTensorGetAttr(shape))(row_indices, row_values)
56
57
def test_sparse_tensor_attr():
    """Smoke-test reading SparseTensor attributes inside a Cell, forward and grad.

    The test only checks that both calls run without raising; the returned
    values are not inspected.
    """
    class SparseTensorGetAttr(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense_shape = (3, 4)

        def construct(self, indices, values):
            sparse = SparseTensor(indices, values, self.dense_shape)
            return sparse.values, sparse.indices, sparse.dense_shape

    coords = Tensor([[0, 1], [1, 2]])
    entries = Tensor([1, 2], dtype=ms.float32)

    # Plain forward call.
    SparseTensorGetAttr()(coords, entries)
    # Same network wrapped to take gradients w.r.t. all inputs.
    GradWrap(SparseTensorGetAttr())(coords, entries)
71