# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate vm_impl function for nn ops"""
import numpy as np

from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm


# pylint: disable=unused-argument


@vm_impl_getters.register(P.ScalarSummary)
def vm_impl_scalar_summary(self):
    """Build the VM-mode callable for ScalarSummary.

    The callable ignores the summary tag and hands the scalar straight back.
    """

    def vm_impl(string_in, scalar):
        """Return ``scalar`` untouched; ``string_in`` is not consulted."""
        return scalar

    return vm_impl


@vm_impl_getters.register(P.ReLU)
def vm_impl_relu(self):
    """Build the VM-mode callable for ReLU."""

    def vm_impl(x):
        """Run ``vm.relu`` on ``x.asnumpy()`` and rewrap the result as a Tensor."""
        return Tensor(vm.relu(x.asnumpy()))

    return vm_impl


@vm_impl_getters.register(P.Flatten)
def vm_impl_flatten(self):
    """Build the VM-mode callable for Flatten."""

    def vm_impl(x):
        """Run ``vm.flatten_batch`` on ``x.asnumpy()`` and rewrap as a Tensor."""
        return Tensor(vm.flatten_batch(x.asnumpy()))

    return vm_impl


@vm_impl_getters.register(P.Softmax)
def vm_impl_softmax(self):
    """Build the VM-mode callable for Softmax."""

    def vm_impl(x):
        """Run ``vm.softmax`` on ``x.asnumpy()`` and rewrap as a Tensor."""
        return Tensor(vm.softmax(x.asnumpy()))

    return vm_impl
@vm_impl_getters.register(P.LogSoftmax)
def vm_impl_log_softmax(self):
    """Generate vm_impl function for LogSoftmax"""

    def vm_impl(x):
        """Apply ``vm.logsoftmax`` to ``x`` and wrap the result in a Tensor."""
        x = x.asnumpy()
        return Tensor(vm.logsoftmax(x))

    return vm_impl


@vm_impl_getters.register(P.Tanh)
def vm_impl_tanh(self):
    """Generate vm_impl function for Tanh"""

    def vm_impl(x):
        """Apply ``vm.tanh`` to ``x`` and wrap the result in a Tensor."""
        x = x.asnumpy()
        return Tensor(vm.tanh(x))

    return vm_impl


@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm"""

    def vm_impl(x, scale, b, mean, variance):
        """Run ``vm.batch_norm`` on the NumPy forms of all five inputs.

        ``self.epsilon`` is forwarded as ``eps``. Returns the five arrays that
        ``vm.batch_norm`` yields (out, x_mean, x_var, running_mean,
        running_var), each wrapped in a Tensor.
        """
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(
            x, scale, b, mean, variance, eps=self.epsilon)
        return (Tensor(out), Tensor(x_mean), Tensor(x_var),
                Tensor(running_mean), Tensor(running_var))

    return vm_impl


@vm_impl_getters.register(P.Conv2D)
def vm_impl_conv2d(self):
    """Generate vm_impl function for Conv2D"""

    def vm_impl(x, w):
        """Convolve ``x`` with weights ``w`` via ``vm.conv2d`` (no bias).

        Stride, padding and dilation come from the primitive's attributes.
        """
        x = x.asnumpy()
        weight = w.asnumpy()
        bias = None  # Conv2D has no bias input; BiasAdd is a separate op.
        out = vm.conv2d(x, weight, bias, self.stride, self.pad, self.dilation)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.MaxPoolGradWithArgmax)
def vm_impl_max_pool_grad_with_argmax(self):
    """Generate vm_impl function for MaxPoolGradWithArgmax"""

    def vm_impl(x, dout, argmax):
        """Route ``dout`` back through the recorded ``argmax`` positions.

        NOTE(review): kernel height/width are read from ``kernel_size[1]`` and
        ``kernel_size[2]`` and the stride from ``strides[1]``, which presumably
        assumes 4-element attributes; only one stride value is forwarded.
        """
        x = x.asnumpy()
        dout = dout.asnumpy()
        arg_max = argmax.asnumpy()
        dx = vm.max_pool_grad_with_argmax(x, dout, arg_max,
                                          self.kernel_size[1], self.kernel_size[2], self.strides[1])
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(P.MaxPoolWithArgmax)
def vm_impl_max_pool_with_argmax(self):
    """Generate vm_impl function for MaxPoolWithArgmax"""

    def vm_impl(x):
        """Max-pool ``x`` and return ``(pooled, argmax)`` as Tensors.

        Uses the same ``kernel_size[1]``/``kernel_size[2]``/``strides[1]``
        indexing as MaxPoolGradWithArgmax above.
        """
        x = x.asnumpy()
        out, out_argmax = vm.max_pool_with_argmax(x, self.kernel_size[1], self.kernel_size[2], self.strides[1])
        return Tensor(out), Tensor(out_argmax)

    return vm_impl


@vm_impl_getters.register(P.MaxPool)
def vm_impl_max_pool(self):
    """Generate vm_impl function for MaxPool"""

    def vm_impl(x):
        """Max-pool ``x`` using the last two kernel dims and ``strides[-2]``."""
        x = x.asnumpy()
        out = vm.max_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.MaxPoolGrad)
def vm_impl_max_pool_grad(self):
    """Generate vm_impl function for MaxPoolGrad"""

    def vm_impl(x, out, dout):
        """Compute the MaxPool input gradient from ``x`` and ``dout``.

        ``out`` (the forward result) is accepted for interface compatibility
        but is not used by ``vm.max_pool_grad``.
        """
        x = x.asnumpy()
        dout = dout.asnumpy()
        # Bind the result to ``dx`` rather than re-using the ``out`` parameter.
        dx = vm.max_pool_grad(x, dout, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(P.AvgPool)
def vm_impl_avg_pool(self):
    """Generate vm_impl function for AvgPool"""

    def vm_impl(x):
        """Average-pool ``x`` using the last two kernel dims and ``strides[-2]``."""
        x = x.asnumpy()
        out = vm.avg_pooling(x, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.AvgPoolGrad)
def vm_impl_avg_pool_grad(self):
    """Generate vm_impl function for AvgPoolGrad"""

    def vm_impl(dout, origin_shape):
        """Spread ``dout`` back over an input of ``origin_shape``.

        ``origin_shape`` is forwarded unchanged (not converted with asnumpy).
        """
        dout = dout.asnumpy()
        out = vm.avg_pool_grad(dout, origin_shape, self.kernel_size[-2], self.kernel_size[-1], self.strides[-2])
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):
    """Generate vm_impl function for BatchNormGrad"""

    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        """Return ``(dx, dscale, dshift)`` Tensors from ``vm.batch_norm_grad``."""
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))

    return vm_impl


@vm_impl_getters.register(G.ReluGrad)
def vm_impl_relu_grad(self):
    """Generate vm_impl function for ReluGrad"""

    def vm_impl(y_backprop, x):
        """Multiply ``y_backprop`` elementwise by ``vm.relu_grad(x)``."""
        x = x.asnumpy()
        y_backprop = y_backprop.asnumpy()
        # Pass a copy so vm.relu_grad cannot modify the array obtained from x.
        y_backprop = vm.relu_grad(x.copy()) * y_backprop
        return Tensor(y_backprop)

    return vm_impl


@vm_impl_getters.register(P.Conv2DBackpropInput)
def vm_impl_conv2d_backprop_input(self):
    """Generate vm_impl function for Conv2DBackpropInput"""

    def vm_impl(dout, w, x_size):
        """Compute the Conv2D input gradient for an input of size ``x_size``."""
        dout = dout.asnumpy()
        w = w.asnumpy()
        dx = vm.conv2d_backprop_input(dout, x_size, w, self.stride, self.pad)
        return Tensor(dx)

    return vm_impl


@vm_impl_getters.register(G.Conv2DBackpropFilter)
def vm_impl_conv2d_backprop_filter(self):
    """Generate vm_impl function for Conv2DBackpropFilter"""

    def vm_impl(dout, x, w_size):
        """Compute the Conv2D weight gradient for a filter of size ``w_size``."""
        x = x.asnumpy()
        dout = dout.asnumpy()
        dw = vm.conv2d_backprop_filter(dout, x, w_size, self.stride, self.pad)
        return Tensor(dw)

    return vm_impl


@vm_impl_getters.register(G.FlattenGrad)
def vm_impl_flatten_grad(self):
    """Generate vm_impl function for FlattenGrad"""

    def vm_impl(dout, x):
        """Reshape ``dout`` back via ``vm.flatten_grad``.

        ``x`` is forwarded unchanged; presumably it is the target shape.
        NOTE(review): confirm against vm.flatten_grad.
        """
        dout = dout.asnumpy()
        dout = vm.flatten_grad(dout, x)
        return Tensor(dout)

    return vm_impl


@vm_impl_getters.register(P.BiasAdd)
def vm_impl_bias_add(self):
    """Generate vm_impl function for BiasAdd"""

    def vm_impl(wx, bias):
        """Return ``wx + bias`` using NumPy broadcasting (aligns on trailing axes)."""
        wx = wx.asnumpy()
        bias = bias.asnumpy()
        out = wx + bias
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.BiasAddGrad)
def vm_impl_bias_add_grad(self):
    """Generate vm_impl function for BiasAddGrad"""

    def vm_impl(dout):
        """Sum ``dout`` over every axis except the last.

        This mirrors BiasAdd above, whose broadcasting aligns ``bias`` with
        the trailing axis.
        """
        dout = dout.asnumpy()
        shape = np.shape(dout)
        return Tensor(np.add.reduce(dout, axis=tuple(range(len(shape) - 1))))

    return vm_impl


@vm_impl_getters.register(P.SoftmaxCrossEntropyWithLogits)
def vm_impl_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        """Return ``(loss, dx)`` Tensors from ``vm.softmax_cross_entropy_with_logits``."""
        logits = logits.asnumpy()
        labels = labels.asnumpy()
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, labels)
        # np.array() so a (presumably) scalar loss becomes an ndarray for Tensor.
        return (Tensor(np.array(loss)), Tensor(dx))

    return vm_impl


@vm_impl_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def vm_impl_sparse_softmax_cross_entropy_with_logits(self):
    """Generate vm_impl function for SparseSoftmaxCrossEntropyWithLogits"""

    def vm_impl(logits, labels):
        """One-hot encode integer ``labels`` and reuse the dense softmax-CE helper.

        Returns ``(dx,)`` when ``self.is_grad`` is truthy, otherwise ``(loss,)``.
        Assumes ``labels`` is a 1-D integer array with one class index per row
        of ``logits``.
        """
        logits = logits.asnumpy()
        labels = labels.asnumpy()

        n_sample = labels.shape[0]
        # Take the class count from the logits; labels.max() + 1 under-counts
        # whenever the highest class index does not occur in this batch.
        n_class = logits.shape[-1]
        one_hot_label = np.zeros((n_sample, n_class))
        # Set exactly one entry per sample row. ``[:, labels]`` would instead
        # set whole columns for every sample.
        one_hot_label[np.arange(n_sample), labels] = 1
        loss, dx = vm.softmax_cross_entropy_with_logits(logits, one_hot_label)
        if self.is_grad:
            return (Tensor(dx),)
        return (Tensor(np.array(loss)),)

    return vm_impl


@vm_impl_getters.register(P.ApplyMomentum)
def vm_impl_momentum(self):
    """Generate vm_impl function for Momentum"""

    def vm_impl(variable,
                accumulation,
                learning_rate,
                gradient,
                momentum,
                use_nesterov=False):
        """Apply one momentum step and return the updated ``variable``.

        accumulation = accumulation * momentum + gradient, then either
        variable -= gradient * lr + accumulation * momentum * lr (Nesterov) or
        variable -= accumulation * lr.

        Nesterov is used when either the ``use_nesterov`` argument or the
        primitive's ``use_nesterov`` attribute is truthy. The argument alone
        defaults to False, so without the attribute check a primitive built
        with Nesterov enabled would silently take the plain path.
        """
        gradient = gradient.asnumpy()
        accumulation = accumulation.asnumpy()
        variable = variable.asnumpy()
        shape = accumulation.shape
        learning_rate = np.full(shape, learning_rate.asnumpy())
        momentum = np.full(shape, momentum.asnumpy())
        accumulation = accumulation * momentum + gradient
        nesterov = bool(use_nesterov) or bool(getattr(self, "use_nesterov", False))
        # NOTE(review): ``-=`` updates the array returned by asnumpy() in
        # place; whether that writes through to the input Tensor depends on
        # asnumpy's storage semantics. The updated accumulation is not
        # written back.
        if nesterov:
            variable -= gradient * learning_rate + accumulation * momentum * learning_rate
        else:
            variable -= accumulation * learning_rate
        return Tensor(variable)

    return vm_impl


@vm_impl_getters.register(P.ResizeBilinear)
def vm_impl_resize_bilinear(self):
    """Generate vm_impl function for ResizeBilinear"""

    def vm_impl(x):
        """Delegate to ``vm.ResizeBilinear`` and wrap the result in a Tensor.

        NOTE(review): unlike the other ops here, ``x`` is passed without
        ``asnumpy()``. No size/align_corners attributes are forwarded.
        Confirm what vm.ResizeBilinear expects.
        """
        out = vm.ResizeBilinear(x)
        return Tensor(out)

    return vm_impl


@vm_impl_getters.register(G.ResizeBilinearGrad)
def vm_impl_resize_bilinear_grad(self):
    """Generate vm_impl function for ResizeBilinearGrad"""

    def vm_impl(dout, original_image):
        """Delegate to ``vm.ResizeBilinearGrad`` and wrap the result in a Tensor.

        NOTE(review): inputs are passed as-is, without ``asnumpy()``.
        """
        out = vm.ResizeBilinearGrad(dout, original_image)
        return Tensor(out)

    return vm_impl