# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl functions for nn ops."""
import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm


# pylint: disable=unused-argument
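
# Note on the registration pattern used throughout this file: each getter is
# registered against one primitive class and, when called with the primitive
# instance (`self`), returns a numpy-backed `vm_impl` closure. A minimal,
# hedged usage sketch (assuming the registry module also exposes a lookup
# helper such as `get_vm_impl_fn`; adjust to the actual API):
#
#   from mindspore.ops.vm_impl_registry import get_vm_impl_fn
#   relu_fn = get_vm_impl_fn(P.ReLU())
#   out = relu_fn(Tensor(np.array([-1.0, 2.0], dtype=np.float32)))
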
@vm_impl_getters.register(P.ScalarSummary)
def vm_impl_scalar_summary(self):
    """Generate vm_impl function for ScalarSummary"""

    def vm_impl(string_in, scalar):
        """Implement by vm mode."""
        return scalar

    return vm_impl


@vm_impl_getters.register(P.ReLU)
def vm_impl_relu(self):
    """Generate vm_impl function for ReLU"""

    def vm_impl(x):
        x = x.asnumpy()
        output = Tensor(vm.relu(x))
        return output

    return vm_impl


@vm_impl_getters.register(P.Flatten)
def vm_impl_flatten(self):
    """Generate vm_impl function for Flatten"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.flatten_batch(x))

    return vm_impl


@vm_impl_getters.register(P.Softmax)
def vm_impl_softmax(self):
    """Generate vm_impl function for Softmax"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.softmax(x))

    return vm_impl


@vm_impl_getters.register(P.LogSoftmax)
def vm_impl_log_softmax(self):
    """Generate vm_impl function for LogSoftmax"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.logsoftmax(x))

    return vm_impl


@vm_impl_getters.register(P.Tanh)
def vm_impl_tanh(self):
    """Generate vm_impl function for Tanh"""

    def vm_impl(x):
        x = x.asnumpy()
        return Tensor(vm.tanh(x))

    return vm_impl


@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm"""

    def vm_impl(x, scale, b, mean, variance):
        # pylint: disable=unused-argument
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(
            x, scale, b, mean, variance, eps=self.epsilon)
        return (Tensor(out), Tensor(x_mean), Tensor(x_var),
                Tensor(running_mean), Tensor(running_var))

    return vm_impl
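
# For reference, the per-channel normalization computed here is the standard
# batch norm formula (a sketch of the math, not a claim about vm internals):
#
#   out = scale * (x - batch_mean) / sqrt(batch_var + eps) + b
#
# with the running mean/variance updated from the batch statistics.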


@vm_impl_getters.register(P.Conv2D)
def vm_impl_conv2d(self):
    """Generate vm_impl function for Conv2D"""

    def vm_impl(x, w):
        x = x.asnumpy()
        weight = w.asnumpy()
        bias = None
        out = vm.conv2d(x, weight, bias, self.stride, self.pad, self.dilation)
        return Tensor(out)

    return vm_impl
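
# Shape sketch, assuming the usual NCHW layout for MindSpore Conv2D: for x of
# shape (N, C_in, H, W) and weight of shape (C_out, C_in, kH, kW), vm.conv2d
# yields (N, C_out, H_out, W_out); with stride 1 and no padding or dilation,
# H_out = H - kH + 1 and W_out = W - kW + 1.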


@vm_impl_getters.register(G.MaxPoolGradWithArgmax)
def vm_impl_max_pool_grad_with_argmax(self):
    """Generate vm_impl function for MaxPoolGradWithArgmax"""

    def vm_impl(x, dout, argmax):
        x = x.asnumpy()
        dout = dout.asnumpy()
        arg_max = argmax.asnumpy()
        dx = vm.max_pool_grad_with_argmax(x, dout, arg_max,
                                          self.kernel_size[1], self.kernel_size[2], self.strides[1])
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(P.MaxPoolWithArgmax)
def vm_impl_max_pool_with_argmax(self):
    """Generate vm_impl function for MaxPoolWithArgmax"""

    def vm_impl(x):
        x = x.asnumpy()
        out, out_argmax = vm.max_pool_with_argmax(x, self.kernel_size[1], self.kernel_size[2], self.strides[1])
        return Tensor(out), Tensor(out_argmax)

    return vm_impl


@vm_impl_getters.register(P.MaxPool)
def vm_impl_max_pool(self):
    """Generate vm_impl function for MaxPool"""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.max_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl
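
# Worked example for the window math used by the pooling impls above: a 4x4
# feature map with a 2x2 kernel (kernel_size[-2] x kernel_size[-1]) and
# stride 2 produces a 2x2 output, each cell taking one non-overlapping
# 2x2 window's maximum (or, for AvgPool below, its mean).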


@vm_impl_getters.register(G.MaxPoolGrad)
def vm_impl_max_pool_grad(self):
    """Generate vm_impl function for MaxPoolGrad"""

    def vm_impl(x, out, dout):
        x = x.asnumpy()
        dout = dout.asnumpy()
        out = vm.max_pool_grad(x, dout, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(P.AvgPool)
def vm_impl_avg_pool(self):
    """Generate vm_impl function for AvgPool"""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.avg_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.AvgPoolGrad)
def vm_impl_avg_pool_grad(self):
    """Generate vm_impl function for AvgPoolGrad"""

    def vm_impl(dout, origin_shape):
        dout = dout.asnumpy()
        out = vm.avg_pool_grad(dout, origin_shape, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


# pylint: disable=function-redefined
@vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):
    """Generate vm_impl function for BatchNormGrad"""

    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))

    return vm_impl


@vm_impl_getters.register(G.ReluGrad)
def vm_impl_relu_grad(self):
    """Generate vm_impl function for ReluGrad"""

    def vm_impl(y_backprop, x):
        x = x.asnumpy()
        y_backprop = y_backprop.asnumpy()
        y_backprop = vm.relu_grad(x.copy()) * y_backprop
        return Tensor(y_backprop)

    return vm_impl
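
# This is the standard ReLU backward rule: vm.relu_grad(x) evidently acts as
# the 0/1 activation mask, so the result is dx = y_backprop * 1[x > 0]
# elementwise.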


@vm_impl_getters.register(P.Conv2DBackpropInput)
def vm_impl_conv2d_backprop_input(self):
    """Generate vm_impl function for Conv2DBackpropInput"""

    def vm_impl(dout, w, x_size):
        dout = dout.asnumpy()
        w = w.asnumpy()
        dx = vm.conv2d_backprop_input(dout, x_size, w, self.stride, self.pad)
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(G.Conv2DBackpropFilter)
def vm_impl_conv2d_backprop_filter(self):
    """Generate vm_impl function for Conv2DBackpropFilter"""

    def vm_impl(dout, x, w_size):
        x = x.asnumpy()
        dout = dout.asnumpy()
        dw = vm.conv2d_backprop_filter(dout, x, w_size, self.stride, self.pad)
        return Tensor(dw)

    return vm_impl


@vm_impl_getters.register(G.FlattenGrad)
def vm_impl_flatten_grad(self):
    """Generate vm_impl function for FlattenGrad"""

    def vm_impl(dout, x):
        dout = dout.asnumpy()
        dout = vm.flatten_grad(dout, x)
        return Tensor(dout)

    return vm_impl


@vm_impl_getters.register(P.BiasAdd)
def vm_impl_bias_add(self):
    """Generate vm_impl function for BiasAdd"""

    def vm_impl(wx, bias):
        wx = wx.asnumpy()
        bias = bias.asnumpy()
        out = wx + bias
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.BiasAddGrad)
def vm_impl_bias_add_grad(self):
    """Generate vm_impl function for BiasAddGrad"""

    def vm_impl(dout):
        dout = dout.asnumpy()
        shape = np.shape(dout)
        return Tensor(np.add.reduce(dout, axis=tuple(range(len(shape) - 1))))

    return vm_impl
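
# The reduction above sums dout over every axis except the last (channel)
# axis, so a (N, C) gradient collapses to a (C,) bias gradient, mirroring
# how the bias was broadcast in the forward BiasAdd.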


@vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits)
def vm_impl_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, labels)
        return (Tensor(np.array(loss)), Tensor(dx))

    return vm_impl
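
# In the standard formulation this kernel returns both the scalar loss and
# the gradient with respect to the logits, which for softmax cross-entropy
# reduces to dx = softmax(logits) - labels.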


@vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def vm_impl_sparse_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        logits = logits.asnumpy()
        labels = labels.asnumpy()

        n_class = labels.max() + 1
        n_sample = labels.shape[0]
        one_hot_label = np.zeros((n_sample, n_class))  # one row per sample, one column per class
        one_hot_label[np.arange(n_sample), labels] = 1  # set each sample's label column to 1
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label)
        if self.is_grad:
            return (Tensor(dx),)
        return (Tensor(np.array(loss)),)

    return vm_impl
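
# Example of the one-hot conversion above: labels = [0, 2] with n_class = 3
# becomes [[1, 0, 0], [0, 0, 1]], which lets the dense cross-entropy kernel
# be reused for the sparse (integer-label) op.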


@vm_impl_getters.register(P.ApplyMomentum)
def vm_impl_momentum(self):
    """Generate vm_impl function for Momentum"""

    def vm_impl(variable,
                accumulation,
                learning_rate,
                gradient,
                momentum,
                use_nesterov=False):
        gradient = gradient.asnumpy()
        accumulation = accumulation.asnumpy()
        variable = variable.asnumpy()
        shape = accumulation.shape
        learning_rate = np.full(shape, learning_rate.asnumpy())
        momentum = np.full(shape, momentum.asnumpy())
        accumulation = accumulation * momentum + gradient
        if use_nesterov:
            variable -= gradient * learning_rate + accumulation * momentum * learning_rate
        else:
            variable -= accumulation * learning_rate
        return Tensor(variable)

    return vm_impl
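
# Update rules implemented above, in conventional momentum notation:
#   accum <- momentum * accum + grad
#   var   <- var - lr * accum                      (plain momentum)
#   var   <- var - lr * (grad + momentum * accum)  (Nesterov variant)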


@vm_impl_getters.register(P.ResizeBilinear)
def vm_impl_resize_bilinear(self):
    """Generate vm_impl function for ResizeBilinear"""

    def vm_impl(x):
        x = x.asnumpy()
        out = vm.ResizeBilinear(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.ResizeBilinearGrad)
def vm_impl_resize_bilinear_grad(self):
    """Generate vm_impl function for ResizeBilinearGrad"""

    def vm_impl(dout, original_image):
        dout = dout.asnumpy()
        original_image = original_image.asnumpy()
        out = vm.ResizeBilinearGrad(dout, original_image)
        return Tensor(out)

    return vm_impl