# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for math ops"""
import copy

import numpy as np

from mindspore.common.dtype import dtype_to_nptype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm


# pylint: disable=unused-argument


@vm_impl_getters.register(P.Add)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for Add.

    The returned callable converts both Tensor inputs to numpy, adds them
    with numpy's ``+`` operator and wraps the result in a Tensor.
    """

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x + y)

    return vm_impl


# pylint: disable=used-before-assignment


@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
    """Generate vm_impl function for LogicalNot.

    The computation is delegated to ``vm.logical_not`` on the numpy view of
    the input.
    """

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.logical_not(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):
    """Generate vm_impl function for MatMul.

    Honors the primitive's ``transpose_a`` / ``transpose_b`` flags before
    computing ``x @ w``.
    """

    def vm_impl(x, w):
        x = x.asnumpy()
        w = w.asnumpy()
        # ndarray.transpose() with no arguments reverses all axes; for the
        # 2-D operands MatMul expects this is the ordinary matrix transpose.
        if self.transpose_a:
            x = x.transpose()
        if self.transpose_b:
            w = w.transpose()
        z = x @ w
        return Tensor(z)

    return vm_impl


@vm_impl_getters.register(P.AddN)
def vm_impl_addn(self):
    """Generate vm_impl function for AddN.

    The returned callable sums every Tensor in ``inputs`` (at least one
    element is required, ``inputs[0]`` is read unconditionally).
    """

    def vm_impl(inputs):
        # Copy the first operand so the in-place ``+=`` below never writes
        # into the buffer backing the caller's first input tensor.
        added = copy.deepcopy(inputs[0].asnumpy())
        for x in inputs[1:]:
            added += x.asnumpy()
        return Tensor(added)

    return vm_impl


@vm_impl_getters.register(P.Neg)
def vm_impl_neg(self):
    """Generate vm_impl function for Neg."""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(-x)

    return vm_impl


@vm_impl_getters.register(P.Sub)
def vm_impl_Sub(self):
    """Generate vm_impl function for Sub."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x - y)

    return vm_impl


@vm_impl_getters.register(P.Mul)
def vm_impl_mul(self):
    """Generate vm_impl function for Mul."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x * y)

    return vm_impl


@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square."""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(x * x)

    return vm_impl


@vm_impl_getters.register(P.Sqrt)
def vm_impl_sqrt(self):
    """Generate vm_impl function for Sqrt (delegates to ``vm.sqrt``)."""

    def vm_impl(x):
        x = x.asnumpy()
        res = vm.sqrt(x)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.Pow)
def vm_impl_pow(self):
    """Generate vm_impl function for Pow (delegates to ``vm.power``)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        res = vm.power(x, y)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.Exp)
def vm_impl_exp(self):
    """Generate vm_impl function for Exp (delegates to ``vm.exp``)."""

    def vm_impl(x):
        x = x.asnumpy()
        res = vm.exp(x)
        return Tensor(res)

    return vm_impl


@vm_impl_getters.register(P.RealDiv)
def vm_impl_real_div(self):
    """Generate vm_impl function for RealDiv.

    The quotient is cast back to the dtype of ``x``.
    """

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = x / y
        # numpy true division can promote the dtype (e.g. int -> float64);
        # cast back so the output dtype matches the first input.
        out = np.array(out, x.dtype)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Div)
def vm_impl_div(self):
    """Generate vm_impl function for Div (numpy true division, no cast)."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        return Tensor(x / y)

    return vm_impl


@vm_impl_getters.register(P.ReduceMean)
def vm_impl_reduce_mean(self):
    """Generate vm_impl function for ReduceMean.

    Axis handling is delegated entirely to ``vm.mean``.
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        out = vm.mean(x, axis)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax.

    The primitive's ``keep_dims`` attribute is honored when present
    (defaults to False otherwise).
    """

    def vm_impl(x, axis):
        x = x.asnumpy()
        # An empty axis tuple means "reduce over every dimension", which
        # np.amax expresses as axis=None.
        if axis == ():
            axis = None
        keep_dims = getattr(self, "keep_dims", False)
        out = np.amax(x, axis, keepdims=keep_dims)
        # A full reduction yields a numpy scalar; np.array turns it into a
        # 0-d ndarray so Tensor always receives an ndarray.
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
    """Generate vm_impl function for Equal."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.equal(x, y)
        # np.array ensures an ndarray even if the helper returns a scalar.
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.NotEqual)
def vm_impl_not_equal(self):
    """Generate vm_impl function for NotEqual."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.not_equal(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Greater)
def vm_impl_greater(self):
    """Generate vm_impl function for Greater."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.greater(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.Maximum)
def vm_impl_maximum(self):
    """Generate vm_impl function for Maximum."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.maximum(x, y)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Minimum)
def vm_impl_minimum(self):
    """Generate vm_impl function for Minimum."""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.minimum(x, y)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.Less)
def vm_impl_less(self):
    """Generate vm_impl function for Less"""

    def vm_impl(x, y):
        x = x.asnumpy()
        y = y.asnumpy()
        out = vm.less(x, y)
        return Tensor(np.array(out))

    return vm_impl


@vm_impl_getters.register(P.ScalarCast)
def vm_impl_scalar_cast(self):
    """Generate vm_impl function for ScalarCast.

    The returned callable casts the Python scalar ``x`` to the numpy type
    matching MindSpore dtype ``t`` and returns a plain Python scalar.
    """

    def vm_impl(x, t):
        np_type = dtype_to_nptype(t)
        value = np_type(x)
        # .item() converts the numpy scalar back to a native Python value.
        cast_value = value.item()
        return cast_value

    return vm_impl