• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Stub Tensor implementation."""
16
17import inspect
18from functools import reduce
19from mindspore.common.tensor import Tensor
20from mindspore.common.dtype import type_size_in_bytes
21from mindspore._c_expression import TensorNode, SequenceNode, NoneTypeNode, AnyTypeNode
22from mindspore._c_expression import Tensor as Tensor_
23from mindspore.common.api import _convert_python_data
24
25
26def _stub_member(var, init):
27    """handle stub tensor's member, use a member cache to improve performance"""
28    def getx(stub):
29        if stub.tensor is not None:
30            return getattr(stub.tensor, var)
31        if hasattr(stub, "member_cache"):
32            return stub.member_cache.get(var, init)
33        return init
34
35    def setx(stub, value):
36        if stub.tensor is not None:
37            setattr(stub.tensor, var, value)
38        else:
39            if not hasattr(stub, "member_cache"):
40                stub.member_cache = {}
41            stub.member_cache[var] = value
42    return property(getx, setx)
43
44
45def _stub_method(method):
46    def fun(*arg, **kwargs):
47        stub = arg[0]
48        arg = (stub.stub_sync(),) + arg[1:]
49        return method(*arg, **kwargs)
50    return fun
51
52
class StubTensor:
    """Stand-in tensor used for asynchronous op execution.

    An instance carries a `stub` handle and/or a real `tensor`. This class only relies on the
    stub offering `get_shape()`, `get_dtype()` and `get_value()`. After `stub_sync` has run,
    `stub` is cleared and `tensor` holds the real `Tensor`.
    """
    # Members proxied through `_stub_member`: forwarded to the real tensor when it exists,
    # otherwise buffered in `member_cache` (second argument is the default for unset members).
    const_arg = _stub_member("const_arg", None)
    init = _stub_member("init", None)
    init_finished = _stub_member("init_finished", False)
    virtual_flag = _stub_member("virtual_flag", False)
    adapter_flag = _stub_member("adapter_flag", False)
    parent_tensor_ = _stub_member("parent_tensor_", None)
    index_of_parent_ = _stub_member("index_of_parent_", None)
    slice_num_of_persistent_data_ = _stub_member("slice_num_of_persistent_data_", None)
    slice_shape_of_persistent_data_ = _stub_member("slice_shape_of_persistent_data_", None)
    # auto gradient information
    _grad = _stub_member("_grad", None)
    _grad_fn = _stub_member("_grad_fn", None)
    _requires_grad = _stub_member("_requires_grad", False)
    _retain_grad = _stub_member("_retain_grad", False)

    def __init__(self, stub=None, tensor=None):
        self.stub = stub
        self.tensor = tensor

    # These need the real tensor, so they sync before delegating to Tensor.
    __str__ = _stub_method(Tensor.__str__)
    __repr__ = _stub_method(Tensor.__repr__)
    __setitem__ = _stub_method(Tensor.__setitem__)

    # Comparison operators reuse Tensor's implementations directly (no sync wrapper).
    __lt__ = Tensor.__lt__
    __le__ = Tensor.__le__
    __gt__ = Tensor.__gt__
    __ge__ = Tensor.__ge__
    __eq__ = Tensor.__eq__
    __ne__ = Tensor.__ne__

    @property
    def shape(self):
        """Shape of the tensor; queried from the stub once and cached while the stub is set."""
        if not self.stub:
            return self.tensor.shape
        if not hasattr(self, "stub_shape"):
            self.stub_shape = self.stub.get_shape()
        return self.stub_shape

    @property
    def dtype(self):
        """Dtype of the tensor; queried from the stub once and cached while the stub is set."""
        if not self.stub:
            return self.tensor.dtype
        if not hasattr(self, "stub_dtype"):
            self.stub_dtype = self.stub.get_dtype()
        return self.stub_dtype

    @property
    def size(self):
        """Number of elements: the product of all dimensions, or 1 for an empty shape."""
        dims = self.shape
        if not dims:
            return 1
        return reduce(lambda acc, dim: acc * dim, dims)

    @property
    def itemsize(self):
        """Size in bytes of one element, derived from `dtype`."""
        return type_size_in_bytes(self.dtype)

    @property
    def nbytes(self):
        """Total size in bytes: `size * itemsize`."""
        return self.size * self.itemsize

    @property
    def ndim(self):
        """Number of dimensions, i.e. the length of `shape`."""
        return len(self.shape)

    @property
    def strides(self):
        """Strides of the real tensor; reading this syncs the stub."""
        return self.stub_sync().strides

    @property
    def has_init(self):
        """Always False for a stub tensor."""
        return False

    def ndimension(self):
        r"""
        Alias for :func:`mindspore.Tensor.ndim`.
        """
        return self.ndim

    def dim(self):
        r"""
        Alias for :func:`mindspore.Tensor.ndim`.
        """
        return self.ndim

    # Tensor methods that must run on the real tensor: each call syncs first.
    asnumpy = _stub_method(Tensor.asnumpy)
    is_persistent_data = _stub_method(Tensor.is_persistent_data)
    asnumpy_of_slice_persistent_data = _stub_method(Tensor.asnumpy_of_slice_persistent_data)
    slice_num_of_persistent_data = _stub_method(Tensor.slice_num_of_persistent_data)
    slice_shape_of_persistent_data = _stub_method(Tensor.slice_shape_of_persistent_data)
    flush_from_cache = _stub_method(Tensor.flush_from_cache)
    contiguous = _stub_method(Tensor.contiguous)
    is_contiguous = _stub_method(Tensor.is_contiguous)
    register_hook = _stub_method(Tensor.register_hook)

    def stub_sync(self):
        """Materialize the real tensor from the stub (if still set) and return it.

        Members buffered in `member_cache` are copied onto the new tensor, and the stub is
        dropped afterwards so later calls simply return the existing tensor.
        """
        if not self.stub:
            return self.tensor
        self.tensor = Tensor(self.stub.get_value(), internal=True)
        # Replay the member writes that happened before the real tensor existed.
        for name, cached in getattr(self, "member_cache", {}).items():
            setattr(self.tensor, name, cached)
        self.stub = None
        return self.tensor

    def __getstate__(self):
        """Return the pickle state: a dict holding the single key ``"value"``."""
        # NOTE(review): the payload comes from two different calls depending on whether the stub
        # is still set - confirm `Tensor(value, internal=True)` in `__setstate__` accepts both.
        if self.stub:
            payload = self.stub.get_value()
        else:
            payload = self.tensor.__getstate__()
        return {"value": payload}

    def __setstate__(self, state):
        """Restore from pickle state; the result always holds a real tensor and no stub."""
        payload = state.pop("value")
        self.stub = None
        self.tensor = Tensor(payload, internal=True)
177
178
def _init_stub_tensor_api():
    """Copy onto StubTensor every Tensor attribute that StubTensor does not already have.

    Attributes whose names also appear on the cpp tensor class are wrapped with
    `_stub_method`, so the real tensor is synced before they run; all other
    attributes are attached unchanged.
    """
    cpp_api_names = frozenset(dir(Tensor_))
    missing_names = set(dir(Tensor)) - set(dir(StubTensor))
    for name in missing_names:
        # getattr_static fetches the raw class attribute without invoking descriptors.
        impl = inspect.getattr_static(Tensor, name)
        # for cpp tensor api, we always need to sync for real tensor first
        setattr(StubTensor, name, _stub_method(impl) if name in cpp_api_names else impl)
190
191
# Runs once at import time so StubTensor exposes the rest of the Tensor API before first use.
_init_stub_tensor_api()
193
194
def _convert_stub(stub):
    """Recursively convert a stub node, or a tuple of them, to StubTensor / python values.

    TensorNode becomes a StubTensor; tuples and SequenceNode become tuples of converted
    elements; NoneTypeNode and any other input go through `_convert_python_data`;
    AnyTypeNode is unwrapped to its real node and converted again.
    """
    if isinstance(stub, TensorNode):
        return StubTensor(stub)
    if isinstance(stub, (tuple, SequenceNode)):
        # A plain tuple is iterated directly; a SequenceNode exposes its items via get_elements().
        items = stub if isinstance(stub, tuple) else stub.get_elements()
        return tuple(_convert_stub(item) for item in items)
    if isinstance(stub, NoneTypeNode):
        return _convert_python_data(stub.get_real_value())
    if isinstance(stub, AnyTypeNode):
        return _convert_stub(stub.get_real_node())
    return _convert_python_data(stub)
211