• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Generate vm_impl function for math ops"""
16import copy
17import numpy as np
18
19from mindspore.common.dtype import dtype_to_nptype
20from mindspore.common.tensor import Tensor
21from mindspore.ops import operations as P
22from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
23from .vm_interface import vm
24
25
26# pylint: disable=unused-argument
27
28
@vm_impl_getters.register(P.Add)
def vm_impl_tensor_add(self):
    """Generate vm_impl function for Add."""

    def vm_impl(x, y):
        # Work on the numpy arrays exposed by both operands.
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        total = lhs + rhs
        return Tensor(total)

    return vm_impl
39
40
41# pylint: disable=used-before-assignment
@vm_impl_getters.register(P.LogicalNot)
def vm_impl_logical_not(self):
    """Generate vm_impl function for LogicalNot."""

    def vm_impl(x):
        # Delegate to vm.logical_not on the numpy array and wrap the result.
        operand = x.asnumpy()
        return Tensor(vm.logical_not(operand))

    return vm_impl
50
@vm_impl_getters.register(P.MatMul)
def vm_impl_mat_mul(self):
    """Generate vm_impl function for MatMul.

    The returned function reads `self.transpose_a` / `self.transpose_b` to
    decide whether each operand is transposed before the product.
    """

    def vm_impl(x, w):
        lhs = x.asnumpy()
        rhs = w.asnumpy()
        # Apply the optional transposes requested on the primitive.
        lhs = lhs.transpose() if self.transpose_a else lhs
        rhs = rhs.transpose() if self.transpose_b else rhs
        return Tensor(lhs @ rhs)

    return vm_impl
66
67
@vm_impl_getters.register(P.AddN)
def vm_impl_addn(self):
    """Generate vm_impl function for AddN."""

    def vm_impl(inputs):
        # Start from a deep copy of the first operand so the in-place
        # accumulation below never writes into the array it came from.
        total = copy.deepcopy(inputs[0].asnumpy())
        for operand in inputs[1:]:
            total += operand.asnumpy()
        return Tensor(total)

    return vm_impl
79
80
@vm_impl_getters.register(P.Neg)
def vm_impl_neg(self):
    """Generate vm_impl function for Neg."""

    def vm_impl(x):
        # Negate the numpy array and wrap the result.
        operand = x.asnumpy()
        negated = -operand
        return Tensor(negated)

    return vm_impl
90
91
@vm_impl_getters.register(P.Sub)
def vm_impl_Sub(self):
    """Generate vm_impl function for Sub."""

    def vm_impl(x, y):
        # Subtract the second operand from the first on their numpy arrays.
        minuend = x.asnumpy()
        subtrahend = y.asnumpy()
        difference = minuend - subtrahend
        return Tensor(difference)

    return vm_impl
102
103
@vm_impl_getters.register(P.Mul)
def vm_impl_mul(self):
    """Generate vm_impl function for Mul."""

    def vm_impl(x, y):
        # Multiply the numpy arrays of both operands with `*`.
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        product = lhs * rhs
        return Tensor(product)

    return vm_impl
114
115
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Generate vm_impl function for Square."""

    def vm_impl(x):
        # Square by multiplying the numpy array with itself.
        operand = x.asnumpy()
        squared = operand * operand
        return Tensor(squared)

    return vm_impl
125
126
@vm_impl_getters.register(P.Sqrt)
def vm_impl_sqrt(self):
    """Generate vm_impl function for Sqrt."""

    def vm_impl(x):
        # Delegate to vm.sqrt on the numpy array and wrap the result.
        operand = x.asnumpy()
        return Tensor(vm.sqrt(operand))

    return vm_impl
137
138
@vm_impl_getters.register(P.Pow)
def vm_impl_pow(self):
    """Generate vm_impl function for Pow."""

    def vm_impl(x, y):
        # x is passed to vm.power as the first argument, y as the second.
        base = x.asnumpy()
        exponent = y.asnumpy()
        return Tensor(vm.power(base, exponent))

    return vm_impl
150
151
@vm_impl_getters.register(P.Exp)
def vm_impl_exp(self):
    """Generate vm_impl function for Exp."""

    def vm_impl(x):
        # Delegate to vm.exp on the numpy array and wrap the result.
        operand = x.asnumpy()
        return Tensor(vm.exp(operand))

    return vm_impl
162
163
@vm_impl_getters.register(P.RealDiv)
def vm_impl_real_div(self):
    """Generate vm_impl function for RealDiv."""

    def vm_impl(x, y):
        numerator = x.asnumpy()
        denominator = y.asnumpy()
        quotient = numerator / denominator
        # Rebuild the quotient with the dtype of the first operand.
        return Tensor(np.array(quotient, numerator.dtype))

    return vm_impl
176
177
@vm_impl_getters.register(P.Div)
def vm_impl_div(self):
    """Generate vm_impl function for Div."""

    def vm_impl(x, y):
        # Divide with `/`; the result keeps whatever dtype numpy produces.
        numerator = x.asnumpy()
        denominator = y.asnumpy()
        quotient = numerator / denominator
        return Tensor(quotient)

    return vm_impl
188
189
@vm_impl_getters.register(P.ReduceMean)
def vm_impl_reduce_mean(self):
    """Generate vm_impl function for ReduceMean."""

    def vm_impl(x, axis):
        # `axis` is forwarded unchanged to vm.mean.
        operand = x.asnumpy()
        return Tensor(vm.mean(operand, axis))

    return vm_impl
200
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax."""

    def vm_impl(x, axis):
        """Return the maximum of `x` along `axis`, wrapped in a Tensor.

        `axis` may be an int, a tuple of ints or a list of ints. An empty
        tuple or list is mapped to None before calling numpy.
        """
        x = x.asnumpy()
        if isinstance(axis, (tuple, list)):
            # numpy documents None, an int or a tuple of ints for `axis`:
            # map an empty sequence to None and normalise lists to tuples.
            axis = tuple(axis) if axis else None
        out = np.amax(x, axis)
        return Tensor(out)

    return vm_impl
213
@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):
    """Generate vm_impl function for Equal."""

    def vm_impl(x, y):
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        result = vm.equal(lhs, rhs)
        # np.array() wraps the comparison result before building the Tensor.
        return Tensor(np.array(result))

    return vm_impl
225
226
@vm_impl_getters.register(P.NotEqual)
def vm_impl_not_equal(self):
    """Generate vm_impl function for NotEqual."""

    def vm_impl(x, y):
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        result = vm.not_equal(lhs, rhs)
        # np.array() wraps the comparison result before building the Tensor.
        return Tensor(np.array(result))

    return vm_impl
238
239
@vm_impl_getters.register(P.Greater)
def vm_impl_greater(self):
    """Generate vm_impl function for Greater."""

    def vm_impl(x, y):
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        result = vm.greater(lhs, rhs)
        # np.array() wraps the comparison result before building the Tensor.
        return Tensor(np.array(result))

    return vm_impl
251
252
@vm_impl_getters.register(P.Maximum)
def vm_impl_maximum(self):
    """Generate vm_impl function for Maximum."""

    def vm_impl(x, y):
        # Delegate to vm.maximum on the numpy arrays of both operands.
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        return Tensor(vm.maximum(lhs, rhs))

    return vm_impl
264
265
@vm_impl_getters.register(P.Minimum)
def vm_impl_minimum(self):
    """Generate vm_impl function for Minimum."""

    def vm_impl(x, y):
        # Delegate to vm.minimum on the numpy arrays of both operands.
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        return Tensor(vm.minimum(lhs, rhs))

    return vm_impl
277
278
@vm_impl_getters.register(P.Less)
def vm_impl_less(self):
    """Generate vm_impl function for Less."""

    def vm_impl(x, y):
        lhs = x.asnumpy()
        rhs = y.asnumpy()
        result = vm.less(lhs, rhs)
        # np.array() wraps the comparison result before building the Tensor.
        return Tensor(np.array(result))

    return vm_impl
290
291
@vm_impl_getters.register(P.ScalarCast)
def vm_impl_scalar_cast(self):
    """Generate vm_impl function for ScalarCast."""

    def vm_impl(x, t):
        # Look up the numpy type for `t`, build a value of that type from `x`
        # and return the result of calling .item() on it.
        target_type = dtype_to_nptype(t)
        return target_type(x).item()

    return vm_impl
303