# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl functions for math ops.

Each getter below is registered in ``vm_impl_registry`` for one primitive and
returns a ``vm_impl`` closure that converts its Tensor arguments to NumPy
arrays with ``asnumpy()``, computes the result, and wraps it in a ``Tensor``.
"""
import copy
import numpy as np

from mindspore.common.dtype import dtype_to_nptype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm


# pylint: disable=unused-argument


@vm_impl_getters.register(P.Add)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for Add (element-wise ``x + y``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x + y)

    return vm_impl


# pylint: disable=used-before-assignment
@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
    """Generate vm_impl function for LogicalNot (delegates to ``vm.logical_not``)."""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.logical_not(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):
    """Generate vm_impl function for MatMul.

    The closure transposes ``x`` and/or ``w`` according to the primitive's
    ``transpose_a`` / ``transpose_b`` attributes, then computes ``x @ w``.
    """

    def vm_impl(x, w):
        x = x.asnumpy()
        w = w.asnumpy()
        # Attributes are read at call time from the primitive instance.
        if self.transpose_a:
            x = x.transpose()
        if self.transpose_b:
            w = w.transpose()
        z = x @ w
        return Tensor(z)

    return vm_impl


@vm_impl_getters.register(P.AddN)
def vm_impl_addn(self):
    """Generate vm_impl function for AddN (element-wise sum of a sequence)."""

    def vm_impl(inputs):
        # Copy the first operand so the in-place ``+=`` below cannot write into
        # the array returned by ``inputs[0].asnumpy()``.
        added = copy.deepcopy(inputs[0].asnumpy())
        for x in inputs[1:]:
            added += x.asnumpy()
        return Tensor(added)

    return vm_impl


@vm_impl_getters.register(P.Neg)
def vm_impl_neg(self):
    """Generate vm_impl function for Neg (element-wise ``-x``)."""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(-x)

    return vm_impl


@vm_impl_getters.register(P.Sub)
def vm_impl_Sub(self):
    """Generate vm_impl function for Sub (element-wise ``x - y``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x - y)

    return vm_impl


@vm_impl_getters.register(P.Mul)
def vm_impl_mul(self):
    """Generate vm_impl function for Mul (element-wise ``x * y``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x * y)

    return vm_impl


@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square (element-wise ``x * x``)."""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(x * x)

    return vm_impl


@vm_impl_getters.register(P.Sqrt)
def vm_impl_sqrt(self):
    """Generate vm_impl function for Sqrt (delegates to ``vm.sqrt``)."""

    def vm_impl(x):
        x = x.asnumpy()
        res = vm.sqrt(x)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.Pow)
def vm_impl_pow(self):
    """Generate vm_impl function for Pow (delegates to ``vm.power``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        res = vm.power(x, y)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.Exp)
def vm_impl_exp(self):
    """Generate vm_impl function for Exp (delegates to ``vm.exp``)."""

    def vm_impl(x):
        x = x.asnumpy()
        res = vm.exp(x)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.RealDiv)
def vm_impl_real_div(self):
    """Generate vm_impl function for RealDiv.

    Computes ``x / y`` and casts the quotient back to ``x``'s dtype.
    """

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = x / y
        # NumPy's ``/`` may promote (e.g. int -> float64); restore x's dtype.
        out = np.array(out, x.dtype)
        return Tensor(out)

    return vm_impl
@vm_impl_getters.register(P.Div)
def vm_impl_div(self):
    """Generate vm_impl function for Div.

    Computes ``x / y`` with NumPy's true division. Unlike RealDiv, the
    result is not cast back to ``x``'s dtype.
    """

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x / y)

    return vm_impl


@vm_impl_getters.register(P.ReduceMean)
def vm_impl_reduce_mean(self):
    """Generate vm_impl function for ReduceMean (delegates to ``vm.mean``)."""

    # NOTE(review): a keep_dims attribute on the primitive, if present, is not
    # consulted here; confirm whether vm.mean is expected to handle it.
    def vm_impl(x, axis):
        x = x.asnumpy()
        out = vm.mean(x, axis)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax.

    The closure takes the maximum of ``x`` over ``axis``. An empty ``axis``
    means "reduce over every axis". If the primitive has a truthy
    ``keep_dims`` attribute, reduced axes are kept with length 1. Without
    the attribute, they are dropped.
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        # np.amax accepts None, an int or a tuple of ints; normalise lists.
        if isinstance(axis, list):
            axis = tuple(axis)
        if axis == ():
            axis = None
        keep_dims = bool(getattr(self, "keep_dims", False))
        out = np.amax(x, axis, keepdims=keep_dims)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
    """Generate vm_impl function for Equal (delegates to ``vm.equal``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.equal(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.NotEqual)
def vm_impl_not_equal(self):
    """Generate vm_impl function for NotEqual (delegates to ``vm.not_equal``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.not_equal(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Greater)
def vm_impl_greater(self):
    """Generate vm_impl function for Greater (delegates to ``vm.greater``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.greater(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Maximum)
def vm_impl_maximum(self):
    """Generate vm_impl function for Maximum (delegates to ``vm.maximum``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.maximum(x, y)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Minimum)
def vm_impl_minimum(self):
    """Generate vm_impl function for Minimum (delegates to ``vm.minimum``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.minimum(x, y)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Less)
def vm_impl_less(self):
    """Generate vm_impl function for Less (delegates to ``vm.less``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.less(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.ScalarCast)
def vm_impl_scalar_cast(self):
    """Generate vm_impl function for ScalarCast.

    The closure converts the scalar ``x`` to the NumPy type corresponding
    to MindSpore dtype ``t``. It returns the plain Python value produced by
    ``.item()``, not a Tensor.
    """

    def vm_impl(x, t):
        np_type = dtype_to_nptype(t)
        value = np_type(x)
        cast_value = value.item()
        return cast_value

    return vm_impl