# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for array ops"""
import functools

import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm


# Every getter below receives the primitive instance as `self` and returns a
# closure `vm_impl` that evaluates the operator; several closures accept
# arguments they do not use, hence the pylint suppression.
# pylint: disable=unused-argument
@vm_impl_getters.register(P.Assign)
def vm_impl_assign(self):
    """Generate vm_impl function for Assign"""

    def vm_impl(x, value, u=None):
        # Overwrites `x` through its own `assign_value` and hands the same
        # object back; `u` is accepted but not used.
        x.assign_value(value)
        return x

    return vm_impl


@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Generate vm_impl function for ExpandDims"""

    def vm_impl(x, axis):
        # A bare Python float is first wrapped in a one-element Tensor so the
        # `asnumpy()` call below works for it as well.
        if isinstance(x, float):
            x = Tensor(np.array([x]))
        x = x.asnumpy()
        out = vm.expand_dims(x, axis)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.DType)
def vm_impl_dType(self):
    """Generate vm_impl function for DType"""

    def vm_impl(x):
        # Simply reports the `dtype` attribute of the input.
        return x.dtype

    return vm_impl


@vm_impl_getters.register(P.Cast)
def vm_impl_cast(self):
    """Generate vm_impl function for Cast"""

    def vm_impl(x, t):
        # When `t` is a tensor type object, cast to its element type instead.
        if isinstance(t, type(mstype.tensor)):
            t = t.element_type()
        x = x.asnumpy()
        out = x.astype(mstype.dtype_to_nptype(t))
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Reshape)
def vm_impl_reshape(self):
    """Generate vm_impl function for Reshape"""

    def vm_impl(x, shp):
        x = x.asnumpy()
        out = vm.reshape(x, shp)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Shape)
def vm_impl_shape(self):
    """Generate vm_impl function for Shape"""

    def vm_impl(x):
        # The result of `vm.shape` is returned as-is (not wrapped in a Tensor).
        shp = vm.shape(x.asnumpy())
        return shp

    return vm_impl


@vm_impl_getters.register(P.Squeeze)
def vm_impl_squeeze(self):
    """Generate vm_impl function for Squeeze"""

    def vm_impl(x):
        # The axis comes from the primitive's `axis` attribute.
        x = x.asnumpy()
        out = vm.squeeze(x, self.axis)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Transpose)
def vm_impl_transpose(self):
    """Generate vm_impl function for Transpose"""

    def vm_impl(x, perm=None):
        x = x.asnumpy()
        if perm is None:
            # Default permutation: all axes in reverse order.
            perm = list(reversed(range(x.ndim)))
        out = vm.transpose(x, perm)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Split)
def vm_impl_split(self):
    """Generate vm_impl function for Split"""

    def vm_impl(x):
        # A single split index (`self.pos`) yields exactly two pieces, which
        # are returned as a pair of Tensors.
        x = x.asnumpy()
        output = np.array_split(x, (self.pos,))
        return Tensor(output[0]), Tensor(output[1])

    return vm_impl


@vm_impl_getters.register(P.Fill)
def vm_impl_fill(self):
    """Generate vm_impl function for Fill"""

    def vm_impl(dims, x):
        # Python ints produce an int32 array; any other fill value is stored
        # as float32.
        if isinstance(x, int):
            ret = np.full(dims, x, np.int32)
        else:
            ret = np.full(dims, x, np.float32)
        return Tensor(ret)

    return vm_impl


@vm_impl_getters.register(P.Eye)
def vm_impl_eye(self):
    """Generate vm_impl function for Eye"""

    def vm_impl(n, m, t):
        # `t` is translated to a numpy dtype before building the n x m matrix.
        np_type = mstype.dtype_to_nptype(t)
        ret = np.eye(n, m, dtype=np_type)
        return Tensor(ret)

    return vm_impl


@vm_impl_getters.register(P.InvertPermutation)
def vm_impl_invert_permutation(self):
    """Generate vm_impl function for InvertPermutation"""

    def vm_impl(x):
        # Input is forwarded unchanged and the result is returned unwrapped.
        out = vm.invert_permutation(x)
        return out

    return vm_impl


@vm_impl_getters.register(P.Argmax)
def vm_impl_argmax(self):
    """Generate vm_impl function for Argmax"""

    def vm_impl(x):
        # Indices are computed along `self.axis` and flattened to 1-D.
        output = np.argmax(x.asnumpy(), axis=self.axis)
        return Tensor(output.ravel())

    return vm_impl


@vm_impl_getters.register(P.Tile)
def vm_impl_tile(self):
    """Generate vm_impl function for Tile"""

    def vm_impl(x, multiples):
        x = x.asnumpy()
        out = np.tile(x, multiples)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.ReduceAll)
def vm_impl_all(self):
    """Generate vm_impl function for All"""

    def vm_impl(x, axis):
        # `keep_dims` is read from the primitive instance.
        x = x.asnumpy()
        out = vm.all(x, axis, self.keep_dims)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
    """Generate vm_impl function for Any"""

    def vm_impl(x, axis):
        # `keep_dims` is read from the primitive instance.
        x = x.asnumpy()
        out = vm.any(x, axis, self.keep_dims)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Concat)
def vm_impl_concatV2(self):
    """Generate vm_impl function for Concat"""

    def vm_impl(x):
        # The axis comes from the primitive's `axis` attribute.
        x = x.asnumpy()
        out = vm.Concat(x, self.axis)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Slice)
def vm_impl_slice(self):
    """Generate vm_impl function for Slice"""

    def vm_impl(x, begin, size):
        # All three arguments are converted with `asnumpy()` before slicing.
        x = x.asnumpy()
        begin = begin.asnumpy()
        size = size.asnumpy()
        out = vm.Slice(x, begin, size)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self):
    """Generate vm_impl function for ConcatOffset"""

    def vm_impl(x):
        out = vm.ConcatOffset(x)  # out is tuple
        return out

    return vm_impl


@vm_impl_getters.register(P.ReduceSum)
def vm_impl_sum(self):
    """Generate vm_impl function for Sum"""

    def vm_impl(x, axis):
        x = x.asnumpy()
        # An empty tuple means "reduce over everything"; it must not be passed
        # to numpy as `axis=()`, which would reduce over no axes at all.
        if axis == ():
            out = np.sum(x)
        else:
            out = np.sum(x, axis=axis)
        # np.array() guarantees an ndarray even when the sum is a scalar.
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Select)
def vm_impl_select(self):
    """Generate vm_impl function for Select"""

    def vm_impl(cond, x, y):
        """
        Args:
            cond: A `Tensor` of type `bool`
            x: A Tensor which may have the same shape as `cond`.
            y: A `Tensor` with the same shape and type as `x`.
        """
        cond = cond.asnumpy()
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.select(cond, x, y)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(x * x)

    return vm_impl


@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Generate vm_impl function for ZerosLike"""

    def vm_impl(x):
        return Tensor(np.zeros_like(x.asnumpy()))

    # The closure must be returned like in every other getter; without this
    # the getter would yield None instead of an implementation.
    return vm_impl


@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
    """Generate vm_impl function for Partial"""

    def vm_impl(*args):
        # First argument is the callable; the remaining ones are bound to it
        # positionally.
        func = args[0].__call__
        partial_func = functools.partial(func, *args[1:])
        return partial_func

    return vm_impl


@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
    """Generate vm_impl function for Depend"""

    def vm_impl(value, expr):
        # `expr` is ignored; `value` is passed through untouched.
        return value

    return vm_impl


@vm_impl_getters.register(P.UpdateState)
def vm_impl_updatestate(self):
    """Generate vm_impl function for UpdateState"""

    def vm_impl(monad, expr):
        # `expr` is ignored; `monad` is passed through untouched.
        return monad

    return vm_impl


@vm_impl_getters.register(P.Load)
def vm_impl_load(self):
    """Generate vm_impl function for Load"""

    def vm_impl(value, u=None):
        # `u` is ignored; `value` is passed through untouched.
        return value

    return vm_impl