• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Generate vm_impl function for array ops"""
16import numpy as np
17import mindspore.common.dtype as mstype
18from mindspore.common.tensor import Tensor
19from mindspore.ops import operations as P
20from mindspore.ops.operations import _grad_ops as G
21from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
22from .vm_interface import vm
23
24# pylint: disable=unused-argument
@vm_impl_getters.register(P.Assign)
def vm_impl_assign(self):
    """Build the VM implementation of ``Assign``.

    Returns:
        A callable ``(x, value, u=None)`` that stores ``value`` into ``x``
        through ``x.assign_value`` and then returns that same ``x`` object.
        The optional ``u`` argument is accepted but not used.
    """

    def vm_impl(x, value, u=None):
        target = x
        target.assign_value(value)
        return target

    return vm_impl
32
@vm_impl_getters.register(P.ExpandDims)
def vm_impl_expand_dims(self):
    """Build the VM implementation of ``ExpandDims``.

    Returns:
        A callable ``(x, axis)`` that inserts a new dimension into ``x`` at
        ``axis`` via ``vm.expand_dims`` and wraps the result in a Tensor.
        A plain Python ``float`` is first boxed as a one-element Tensor.
    """

    def vm_impl(x, axis):
        # A bare float has no ``asnumpy``; wrap it before converting.
        source = Tensor(np.array([x])) if isinstance(x, float) else x
        expanded = vm.expand_dims(source.asnumpy(), axis)
        return Tensor(expanded)

    return vm_impl
45
46
@vm_impl_getters.register(P.DType)
def vm_impl_dType(self):
    """Build the VM implementation of ``DType``.

    Returns:
        A callable ``(x)`` that reports ``x.dtype`` unchanged.
    """

    def vm_impl(x):
        element_dtype = x.dtype
        return element_dtype

    return vm_impl
56
57
@vm_impl_getters.register(P.Cast)
def vm_impl_cast(self):
    """Build the VM implementation of ``Cast``.

    Returns:
        A callable ``(x, t)`` that converts Tensor ``x`` to the NumPy dtype
        matching MindSpore type ``t`` and returns a new Tensor. If ``t`` is an
        instance of the same class as ``mstype.tensor``, its
        ``element_type()`` is used as the target instead.
    """

    def vm_impl(x, t):
        target = t.element_type() if isinstance(t, type(mstype.tensor)) else t
        src = x.asnumpy()
        return Tensor(src.astype(mstype.dtype_to_nptype(target)))

    return vm_impl
71
72
@vm_impl_getters.register(P.Reshape)
def vm_impl_reshape(self):
    """Build the VM implementation of ``Reshape``.

    Returns:
        A callable ``(x, shp)`` that reshapes ``x``'s NumPy data to ``shp``
        with ``vm.reshape`` and returns the result as a new Tensor.
    """

    def vm_impl(x, shp):
        return Tensor(vm.reshape(x.asnumpy(), shp))

    return vm_impl
83
84
@vm_impl_getters.register(P.Shape)
def vm_impl_shape(self):
    """Build the VM implementation of ``Shape``.

    Returns:
        A callable ``(x)`` returning whatever ``vm.shape`` reports for ``x``'s
        NumPy data; the value is returned as-is, not wrapped in a Tensor.
    """

    def vm_impl(x):
        return vm.shape(x.asnumpy())

    return vm_impl
94
95
@vm_impl_getters.register(P.Squeeze)
def vm_impl_squeeze(self):
    """Build the VM implementation of ``Squeeze``.

    Returns:
        A callable ``(x)`` that removes the dimensions listed in the
        primitive's ``axis`` attribute via ``vm.squeeze`` and returns a Tensor.
    """

    def vm_impl(x):
        squeezed = vm.squeeze(x.asnumpy(), self.axis)
        return Tensor(squeezed)

    return vm_impl
106
107
@vm_impl_getters.register(P.Transpose)
def vm_impl_transpose(self):
    """Build the VM implementation of ``Transpose``.

    Returns:
        A callable ``(x, perm=None)`` that permutes ``x``'s axes with
        ``vm.transpose``. When ``perm`` is omitted, all axes are reversed.
    """

    def vm_impl(x, perm=None):
        data = x.asnumpy()
        if perm is None:
            # Default permutation reverses every axis, e.g. [2, 1, 0] for rank 3.
            perm = list(range(len(data.shape)))[::-1]
        return Tensor(vm.transpose(data, perm))

    return vm_impl
120
121
@vm_impl_getters.register(P.Split)
def vm_impl_split(self):
    """Build the VM implementation of ``Split``.

    Returns:
        A callable ``(x)`` returning a tuple of Tensors.

        * If the primitive carries a ``pos`` attribute, ``x`` is cut once
          along axis 0 at index ``pos``, giving two pieces (legacy behaviour).
        * Otherwise ``x`` is split into ``output_num`` equal pieces along
          ``axis``. ``np.split`` raises ``ValueError`` if the axis length is
          not divisible by ``output_num``.
    """

    def vm_impl(x):
        data = x.asnumpy()
        pos = getattr(self, "pos", None)
        if pos is not None:
            pieces = np.array_split(data, (pos,))
        else:
            # NOTE(review): relies on the primitive exposing ``axis`` and
            # ``output_num`` (the arguments of ``P.Split``) - confirm.
            pieces = np.split(data, self.output_num, axis=self.axis)
        return tuple(Tensor(piece) for piece in pieces)

    return vm_impl
132
133
@vm_impl_getters.register(P.Fill)
def vm_impl_fill(self):
    """Build the VM implementation of ``Fill``.

    Returns:
        A callable ``(dims, x)`` producing a Tensor of shape ``dims`` filled
        with ``x``. Integer values (including ``bool``, an ``int`` subclass)
        yield ``int32`` data; any other value yields ``float32`` data.
    """

    def vm_impl(dims, x):
        fill_dtype = np.int32 if isinstance(x, int) else np.float32
        return Tensor(np.full(dims, x, fill_dtype))

    return vm_impl
146
147
@vm_impl_getters.register(P.Eye)
def vm_impl_eye(self):
    """Build the VM implementation of ``Eye``.

    Returns:
        A callable ``(n, m, t)`` returning an ``n`` x ``m`` Tensor with ones
        on the main diagonal, using the NumPy dtype that matches type ``t``.
    """

    def vm_impl(n, m, t):
        return Tensor(np.eye(n, m, dtype=mstype.dtype_to_nptype(t)))

    return vm_impl
158
159
@vm_impl_getters.register(P.InvertPermutation)
def vm_impl_invert_permutation(self):
    """Build the VM implementation of ``InvertPermutation``.

    Returns:
        A callable ``(x)`` that passes ``x`` straight to
        ``vm.invert_permutation`` and returns its result unwrapped.
    """

    def vm_impl(x):
        return vm.invert_permutation(x)

    return vm_impl
169
170
@vm_impl_getters.register(P.Argmax)
def vm_impl_argmax(self):
    """Build the VM implementation of ``Argmax``.

    Returns:
        A callable ``(x)`` returning a Tensor of indices of the maximum values
        along the primitive's ``axis``. The index array keeps the reduced
        shape; a 0-d result (from 1-D input) is returned with shape ``(1,)``.
    """

    def vm_impl(x):
        indices = np.argmax(x.asnumpy(), axis=self.axis)
        # Promote only a 0-d result to 1-D. The previous ``ravel()`` also
        # flattened multi-dimensional index arrays (rank >= 3 inputs),
        # which lost the reduced output shape.
        return Tensor(np.atleast_1d(indices))

    return vm_impl
180
181
@vm_impl_getters.register(P.Tile)
def vm_impl_tile(self):
    """Build the VM implementation of ``Tile``.

    Returns:
        A callable ``(x, multiples)`` that repeats ``x`` according to
        ``multiples`` with ``np.tile`` and returns a new Tensor.
    """

    def vm_impl(x, multiples):
        return Tensor(np.tile(x.asnumpy(), multiples))

    return vm_impl
192
193
@vm_impl_getters.register(P.ReduceAll)
def vm_impl_all(self):
    """Build the VM implementation of ``ReduceAll``.

    Returns:
        A callable ``(x, axis)`` that reduces ``x`` with ``vm.all`` over
        ``axis``, honouring the primitive's ``keep_dims`` flag.
    """

    def vm_impl(x, axis):
        reduced = vm.all(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)

    return vm_impl
204
205
@vm_impl_getters.register(P.ReduceAny)
def vm_impl_any(self):
    """Build the VM implementation of ``ReduceAny``.

    Returns:
        A callable ``(x, axis)`` that reduces ``x`` with ``vm.any`` over
        ``axis``, honouring the primitive's ``keep_dims`` flag.
    """

    def vm_impl(x, axis):
        reduced = vm.any(x.asnumpy(), axis, self.keep_dims)
        return Tensor(reduced)

    return vm_impl
216
217
@vm_impl_getters.register(P.Concat)
def vm_impl_concatV2(self):
    """Build the VM implementation of ``Concat``.

    Returns:
        A callable ``(x)`` that concatenates along the primitive's ``axis``
        via ``vm.Concat`` and returns a Tensor. ``x`` may be a single object
        with ``asnumpy`` (previous behaviour) or a tuple/list of inputs. Each
        element is converted to NumPy with ``asnumpy`` when available,
        otherwise with ``np.asarray``.
    """

    def vm_impl(x):
        if isinstance(x, (tuple, list)):
            # Sequences have no ``asnumpy``; convert element by element.
            x = tuple(
                item.asnumpy() if hasattr(item, "asnumpy") else np.asarray(item)
                for item in x
            )
        else:
            x = x.asnumpy()
        return Tensor(vm.Concat(x, self.axis))

    return vm_impl
228
229
@vm_impl_getters.register(P.Slice)
def vm_impl_slice(self):
    """Build the VM implementation of ``Slice``.

    Returns:
        A callable ``(x, begin, size)`` that slices ``x`` with ``vm.Slice``
        and returns a Tensor. ``begin`` and ``size`` may be objects with
        ``asnumpy`` (previous behaviour) or plain sequences such as tuples.
        Either form is converted to a NumPy array before calling ``vm.Slice``.
    """

    def _to_numpy(value):
        """Return ``value`` as a NumPy array, via ``asnumpy`` when available."""
        if hasattr(value, "asnumpy"):
            return value.asnumpy()
        return np.asarray(value)

    def vm_impl(x, begin, size):
        out = vm.Slice(x.asnumpy(), _to_numpy(begin), _to_numpy(size))
        return Tensor(out)

    return vm_impl
242
243
@vm_impl_getters.register(G.ConcatOffset)
def vm_impl_concatOffset(self):
    """Build the VM implementation of ``ConcatOffset``.

    Returns:
        A callable ``(x)`` that returns ``vm.ConcatOffset(x)`` directly,
        without wrapping it in a Tensor.
    """

    def vm_impl(x):
        # The result is a tuple, so it is forwarded as-is.
        return vm.ConcatOffset(x)

    return vm_impl
253
254
@vm_impl_getters.register(P.ReduceSum)
def vm_impl_sum(self):
    """Build the VM implementation of ``ReduceSum``.

    Returns:
        A callable ``(x, axis)`` that sums ``x`` and returns a Tensor.
        An empty ``axis`` (``()`` or ``[]``) sums over every element. A list
        ``axis`` is converted to a tuple, since ``np.sum`` rejects lists.
        The primitive's ``keep_dims`` flag (default ``False`` if absent) is
        honoured, consistent with the ReduceAll/ReduceAny implementations.
    """

    def vm_impl(x, axis):
        data = x.asnumpy()
        keep_dims = getattr(self, "keep_dims", False)
        if axis == () or axis == []:
            reduce_axis = None  # reduce over all dimensions
        elif isinstance(axis, list):
            reduce_axis = tuple(axis)
        else:
            reduce_axis = axis
        out = np.sum(data, axis=reduce_axis, keepdims=keep_dims)
        # np.sum may return a NumPy scalar; normalise to an ndarray.
        return Tensor(np.array(out))

    return vm_impl
268
269
@vm_impl_getters.register(P.Select)
def vm_impl_select(self):
    """Build the VM implementation of ``Select``.

    Returns:
        A callable ``(cond, x, y)`` that picks elements from ``x`` or ``y``
        according to ``cond`` using ``vm.select`` and returns a Tensor.
    """

    def vm_impl(cond, x, y):
        """
        Args:
            cond: Tensor of booleans choosing between ``x`` and ``y``.
            x: Tensor supplying values where ``cond`` is true.
            y: Tensor supplying values where ``cond`` is false; expected to
                match ``x`` in shape and type.
        """
        cond_np, x_np, y_np = (arg.asnumpy() for arg in (cond, x, y))
        return Tensor(vm.select(cond_np, x_np, y_np))

    return vm_impl
288
289
@vm_impl_getters.register(P.Square)
def vm_impl_square(self):
    """Build the VM implementation of ``Square``.

    Returns:
        A callable ``(x)`` returning a Tensor of ``x`` multiplied element-wise
        by itself.
    """

    def vm_impl(x):
        arr = x.asnumpy()
        squared = arr * arr
        return Tensor(squared)

    return vm_impl
299
300
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Build the VM implementation of ``ZerosLike``.

    Returns:
        A callable ``(x)`` returning a Tensor of zeros with the same shape and
        dtype as ``x``'s NumPy data.
    """

    def vm_impl(x):
        return Tensor(np.zeros_like(x.asnumpy()))

    # Without this return the getter yielded None instead of a callable,
    # unlike every other getter registered in this module.
    return vm_impl
306
307
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
    """Build the VM implementation of ``Partial``.

    Returns:
        A callable ``(*args)`` that binds ``args[1:]`` as leading positional
        arguments of ``args[0].__call__`` and returns the resulting
        ``functools.partial`` object.
    """
    # ``functools`` is absent from the module's imports; import it here so
    # ``functools.partial`` below does not raise NameError.
    import functools

    def vm_impl(*args):
        func = args[0].__call__
        return functools.partial(func, *args[1:])

    return vm_impl
317
318
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
    """Build the VM implementation of ``Depend``.

    Returns:
        A callable ``(value, expr)`` that returns ``value`` unchanged; ``expr``
        is accepted but not used.
    """

    def vm_impl(value, expr):
        forwarded = value
        return forwarded

    return vm_impl
326
327
@vm_impl_getters.register(P.UpdateState)
def vm_impl_updatestate(self):
    """Build the VM implementation of ``UpdateState``.

    Returns:
        A callable ``(monad, expr)`` that returns ``monad`` unchanged; ``expr``
        is accepted but not used.
    """

    def vm_impl(monad, expr):
        state = monad
        return state

    return vm_impl
335
336
@vm_impl_getters.register(P.Load)
def vm_impl_load(self):
    """Build the VM implementation of ``Load``.

    Returns:
        A callable ``(value, u=None)`` that returns ``value`` unchanged; the
        optional ``u`` argument is accepted but not used.
    """

    def vm_impl(value, u=None):
        loaded = value
        return loaded

    return vm_impl
344