# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Stub Tensor implementation."""

import inspect
from functools import reduce, wraps
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import type_size_in_bytes
from mindspore._c_expression import TensorNode, SequenceNode, NoneTypeNode, AnyTypeNode
from mindspore._c_expression import Tensor as Tensor_
from mindspore.common.api import _convert_python_data


def _stub_member(var, init):
    """handle stub tensor's member, use a member cache to improve performance

    Build a property for attribute ``var`` of :class:`StubTensor`.

    Once the stub holds a real tensor (``stub.tensor is not None``) reads and
    writes go straight to that tensor. Before that, writes are stored in a
    per-instance ``member_cache`` dict and reads return the cached value, or
    ``init`` if ``var`` was never assigned. ``StubTensor.stub_sync`` replays the
    cache onto the real tensor, so touching these members never forces a sync.

    Args:
        var (str): attribute name on the real tensor.
        init: value reported before sync when ``var`` was never assigned.

    Returns:
        property: descriptor to install on ``StubTensor``.
    """
    def getx(stub):
        if stub.tensor is not None:
            return getattr(stub.tensor, var)
        if hasattr(stub, "member_cache"):
            return stub.member_cache.get(var, init)
        return init

    def setx(stub, value):
        if stub.tensor is not None:
            setattr(stub.tensor, var, value)
        else:
            if not hasattr(stub, "member_cache"):
                stub.member_cache = {}
            stub.member_cache[var] = value
    return property(getx, setx)


def _stub_method(method):
    """Wrap a Tensor method so that it runs on the synchronized real tensor.

    The returned function replaces its first positional argument (the
    ``StubTensor``) with ``stub.stub_sync()`` and forwards all other arguments
    to ``method`` unchanged. ``functools.wraps`` keeps ``method``'s name,
    docstring and (via ``__wrapped__``) signature visible to introspection.
    """
    @wraps(method)
    def fun(*arg, **kwargs):
        stub = arg[0]
        arg = (stub.stub_sync(),) + arg[1:]
        return method(*arg, **kwargs)
    return fun


def _stub_property(name, doc=None):
    """Wrap a Tensor property so that it is evaluated on the synchronized real tensor.

    Wrapping a ``property`` object with ``_stub_method`` would turn attribute
    access into a bound method. Instead, this delegates both reads and writes of
    ``name`` to ``stub.stub_sync()``. A read-only property therefore still raises
    on assignment, just as it does on the real tensor.

    Args:
        name (str): attribute name on the real tensor.
        doc (str): docstring for the generated property.

    Returns:
        property: descriptor to install on ``StubTensor``.
    """
    def getx(stub):
        return getattr(stub.stub_sync(), name)

    def setx(stub, value):
        setattr(stub.stub_sync(), name, value)
    return property(getx, setx, doc=doc)


class StubTensor:
    """stub tensor for async op run.

    An instance holds either a pending stub node (``self.stub``) or, after
    :meth:`stub_sync`, a real :class:`Tensor` (``self.tensor``). ``shape`` and
    ``dtype`` (and the metadata derived from them) are answered from the stub
    node without syncing. APIs that need the data call :meth:`stub_sync` first.
    """
    # Members that can be read/written before sync without materializing data.
    const_arg = _stub_member("const_arg", None)
    init = _stub_member("init", None)
    init_finished = _stub_member("init_finished", False)
    virtual_flag = _stub_member("virtual_flag", False)
    adapter_flag = _stub_member("adapter_flag", False)
    parent_tensor_ = _stub_member("parent_tensor_", None)
    index_of_parent_ = _stub_member("index_of_parent_", None)
    slice_num_of_persistent_data_ = _stub_member("slice_num_of_persistent_data_", None)
    slice_shape_of_persistent_data_ = _stub_member("slice_shape_of_persistent_data_", None)
    # auto gradient information
    _grad = _stub_member("_grad", None)
    _grad_fn = _stub_member("_grad_fn", None)
    _requires_grad = _stub_member("_requires_grad", False)
    _retain_grad = _stub_member("_retain_grad", False)

    def __init__(self, stub=None, tensor=None):
        self.stub = stub
        self.tensor = tensor

    __str__ = _stub_method(Tensor.__str__)
    __repr__ = _stub_method(Tensor.__repr__)
    __setitem__ = _stub_method(Tensor.__setitem__)

    __lt__ = Tensor.__lt__
    __le__ = Tensor.__le__
    __gt__ = Tensor.__gt__
    __ge__ = Tensor.__ge__
    __eq__ = Tensor.__eq__
    __ne__ = Tensor.__ne__
    # Assigning __eq__ in a class body makes Python set __hash__ to None, which
    # would make every StubTensor unhashable. "__hash__" is then already in
    # dir(StubTensor), so _init_stub_tensor_api would not copy it either; reuse
    # Tensor's hash explicitly.
    __hash__ = Tensor.__hash__

    @property
    def shape(self):
        """shape stub."""
        if self.stub:
            # Query the stub node once and cache the result on the instance.
            if not hasattr(self, "stub_shape"):
                self.stub_shape = self.stub.get_shape()
            return self.stub_shape
        return self.tensor.shape

    @property
    def dtype(self):
        """dtype stub."""
        if self.stub:
            # Query the stub node once and cache the result on the instance.
            if not hasattr(self, "stub_dtype"):
                self.stub_dtype = self.stub.get_dtype()
            return self.stub_dtype
        return self.tensor.dtype

    @property
    def size(self):
        """size stub: product of the shape dims, 1 for an empty (scalar) shape."""
        shape = self.shape
        return reduce((lambda x, y: x * y), shape) if shape else 1

    @property
    def itemsize(self):
        """itemsize stub: bytes per element of ``dtype``."""
        return type_size_in_bytes(self.dtype)

    @property
    def nbytes(self):
        """nbytes stub: ``size * itemsize``."""
        return self.size * self.itemsize

    @property
    def ndim(self):
        """ndim stub: number of dims in ``shape``."""
        return len(self.shape)

    @property
    def strides(self):
        """strides stub (requires the real tensor, so this forces a sync)."""
        return self.stub_sync().strides

    @property
    def has_init(self):
        """has_init stub: always False for a stub tensor.

        NOTE(review): this ignores any ``init`` member assigned through the
        member cache; confirm that this is intended.
        """
        return False

    def ndimension(self):
        r"""
        Alias for :func:`mindspore.Tensor.ndim`.
        """
        return self.ndim

    def dim(self):
        r"""
        Alias for :func:`mindspore.Tensor.ndim`.
        """
        return self.ndim

    asnumpy = _stub_method(Tensor.asnumpy)
    is_persistent_data = _stub_method(Tensor.is_persistent_data)
    asnumpy_of_slice_persistent_data = _stub_method(Tensor.asnumpy_of_slice_persistent_data)
    slice_num_of_persistent_data = _stub_method(Tensor.slice_num_of_persistent_data)
    slice_shape_of_persistent_data = _stub_method(Tensor.slice_shape_of_persistent_data)
    flush_from_cache = _stub_method(Tensor.flush_from_cache)
    contiguous = _stub_method(Tensor.contiguous)
    is_contiguous = _stub_method(Tensor.is_contiguous)
    register_hook = _stub_method(Tensor.register_hook)

    def stub_sync(self):
        """sync real tensor.

        On the first call while a stub node is held, this does the following:
        - builds the real tensor from ``self.stub.get_value()``;
        - replays members assigned while unsynced onto that tensor;
        - drops the stub node.

        Later calls just return the stored tensor.

        Returns:
            Tensor: the real tensor.
        """
        if self.stub:
            val = self.stub.get_value()
            self.tensor = Tensor(val, internal=True)
            member_cache = getattr(self, "member_cache", None)
            if member_cache is not None:
                for k, v in member_cache.items():
                    setattr(self.tensor, k, v)
                # With self.tensor set, _stub_member properties use the tensor
                # directly, so the cache is dead; release it so it does not
                # keep stale values (e.g. an old _grad) alive.
                del self.member_cache
            self.stub = None
        return self.tensor

    def __getstate__(self):
        """Pickle state: a single ``"value"`` entry consumed by ``__setstate__``.

        NOTE(review): when already synced this stores
        ``self.tensor.__getstate__()`` rather than a tensor value, and
        ``__setstate__`` feeds it to ``Tensor(..., internal=True)``; confirm
        that Tensor's pickle state is a valid constructor input.
        """
        state = {}
        value = self.stub.get_value() if self.stub else self.tensor.__getstate__()
        state["value"] = value
        return state

    def __setstate__(self, state):
        """Restore from ``__getstate__`` output as an already-synced stub."""
        value = state.pop("value")
        self.stub = None
        self.tensor = Tensor(value, internal=True)


def _init_stub_tensor_api():
    """adapt to python tensor and cpp tensor api

    Install on ``StubTensor`` every attribute of ``Tensor`` that ``StubTensor``
    does not define itself:

    * Attributes that also exist on the C++ ``Tensor_`` need the real tensor.
      Properties are delegated through ``_stub_property`` and ordinary methods
      are wrapped with ``_stub_method``; both call ``stub_sync()`` first.
      Static and class methods receive no tensor instance, so they are copied
      unchanged.
    * Python-only attributes are copied unchanged and run with the
      ``StubTensor`` as ``self``.
    """
    need_init_func = set(dir(Tensor)) - set(dir(StubTensor))
    cpp_tensor_func = set(dir(Tensor_))
    for attr in need_init_func:
        # getattr_static returns the raw descriptor (function, property,
        # staticmethod, ...) without binding it.
        func = inspect.getattr_static(Tensor, attr)
        if attr not in cpp_tensor_func or isinstance(func, (staticmethod, classmethod)):
            setattr(StubTensor, attr, func)
        elif isinstance(func, property):
            # A property wrapped as a method would return a bound method on
            # attribute access; delegate the attribute to the real tensor.
            setattr(StubTensor, attr, _stub_property(attr, func.__doc__))
        else:
            # for cpp tensor api, we always need to sync for real tensor first
            setattr(StubTensor, attr, _stub_method(func))
# Install the remaining Tensor APIs on StubTensor once, when this module is imported.
_init_stub_tensor_api()


def _convert_stub(stub):
    """Convert the result of an asynchronous op into Python-side values.

    Dispatch on the node type:

    * ``TensorNode`` -> a :class:`StubTensor` wrapping the node (no sync).
    * ``tuple`` / ``SequenceNode`` -> a ``tuple`` whose items are converted
      recursively (for a ``SequenceNode`` the items are ``get_elements()``).
    * ``NoneTypeNode`` -> ``get_real_value()`` passed to ``_convert_python_data``.
    * ``AnyTypeNode`` -> ``get_real_node()`` converted recursively.
    * anything else -> ``_convert_python_data(stub)``.
    """
    if isinstance(stub, TensorNode):
        return StubTensor(stub)
    if isinstance(stub, (tuple, SequenceNode)):
        children = stub if isinstance(stub, tuple) else stub.get_elements()
        return tuple(map(_convert_stub, children))
    if isinstance(stub, NoneTypeNode):
        return _convert_python_data(stub.get_real_value())
    if isinstance(stub, AnyTypeNode):
        return _convert_stub(stub.get_real_node())
    return _convert_python_data(stub)